EDEN / modeling_eden.py
Rybib's picture
Upload EDEN model and code
453c542 verified
Raw
History Blame Contribute Delete
15.5 kB
"""EDEN model for Hugging Face Transformers.
EDEN (Encoder Decoder Enhancement Network) is a from-scratch encoder-decoder
Transformer that rewrites rough text into polished text. This module wraps the
original architecture in a ``PreTrainedModel`` so the model can be loaded with
``AutoModel.from_pretrained(..., trust_remote_code=True)`` and saved with
``save_pretrained`` as safetensors.
The layer structure (embedding, positional encoding, ``nn.Transformer``, tied
language-model head) matches the original training code exactly, so checkpoints
trained with the standalone trainer need no key remapping: their model weights
load here as-is.
"""
from __future__ import annotations
import math
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
# When this file is loaded from the Hugging Face Hub it lives inside a package,
# so the sibling config is a relative import. When imported directly by a local
# script (for example the conversion script) it is a top-level module instead.
# The try/except supports both, and Transformers ignores imports inside a try
# block when checking dependencies.
try:
from .configuration_eden import EdenConfig
except ImportError:
from configuration_eden import EdenConfig
def _normalise_text(text: str) -> str:
text = str(text or "")
text = text.replace("‘", "'").replace("’", "'")
text = text.replace("“", '"').replace("”", '"')
text = re.sub(r"\s+", " ", text).strip()
return text
def _sentence_split(text: str) -> list[str]:
parts = re.split(r"(?<=[.!?])\s+", text.strip())
return [p for p in parts if p]
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The table ``pe`` has shape ``(1, max_len, d_model)``: even feature columns
    hold sines and odd feature columns hold cosines of the same angles.
    """

    def __init__(self, d_model: int, max_len: int, dropout: float):
        """Build the encoding table.

        Args:
            d_model: Feature width of the table (may be odd or even).
            max_len: Number of positions the table covers.
            dropout: Dropout probability applied after the addition.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half is trimmed to fit. For even d_model the slice is a
        # no-op and the table is identical to the untrimmed computation.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Persistent so the table is written to safetensors and restored on load.
        # Transformers initialises models on the meta device, which would leave a
        # non-persistent buffer uninitialised (NaN) after from_pretrained.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x`` is indexed as ``(batch, seq, d_model)``. The table is cast to the
        dtype of ``x`` before the addition. A sequence longer than ``max_len``
        cannot be broadcast against the sliced table and will raise.
        """
        x = x + self.pe[:, : x.size(1), :].to(dtype=x.dtype)
        return self.dropout(x)
class EdenPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for EDEN: config binding and weight init."""

    # Class-level settings read by the Transformers base class.
    config_class = EdenConfig
    base_model_prefix = "eden"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise a single submodule in place.

        Embeddings get a normal init (mean 0, std 0.02) with the padding row
        zeroed when one is configured. Linear layers get Xavier-uniform
        weights and zero biases. Any other module type is left untouched.
        """
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            pad_row = module.padding_idx
            if pad_row is None:
                return
            with torch.no_grad():
                module.weight[pad_row].zero_()
            return

        if not isinstance(module, nn.Linear):
            return

        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
class EdenForTextEnhancement(EdenPreTrainedModel):
    """Encoder-decoder Transformer with a tied language-model head."""

    _tied_weights_keys = {"lm_head.weight": "embedding.weight"}

    def __init__(self, config: EdenConfig):
        """Build the embedding, positional encoding, Transformer and LM head."""
        super().__init__(config)
        self.pad_id = config.pad_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.unk_id = config.unk_token_id
        # Embeddings are multiplied by sqrt(d_model) before positions are added.
        self.scale = math.sqrt(config.d_model)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.pos = PositionalEncoding(config.d_model, config.max_len + 4, config.dropout)
        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.n_heads,
            num_encoder_layers=config.n_layers,
            num_decoder_layers=config.n_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    # ------------------------------------------------------------------ #
    # Hugging Face plumbing
    # ------------------------------------------------------------------ #
    def get_input_embeddings(self) -> nn.Module:
        """Return the shared token embedding."""
        return self.embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embedding."""
        self.embedding = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def _tie_weights(self) -> None:
        """Make the LM head share the embedding's weight tensor."""
        self.lm_head.weight = self.embedding.weight

    # ------------------------------------------------------------------ #
    # Core encoder-decoder
    # ------------------------------------------------------------------ #
    def encode(self, src: torch.Tensor):
        """Run the encoder over source token ids.

        Returns:
            ``(memory, src_padding)``: the encoder output and the boolean mask
            that is ``True`` wherever ``src`` equals the pad id.
        """
        src_padding = src.eq(self.pad_id)
        src_emb = self.pos(self.embedding(src) * self.scale)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_padding)
        return memory, src_padding

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, src_padding: torch.Tensor) -> torch.Tensor:
        """Run the decoder over target token ids with a causal mask.

        Pad positions in ``tgt`` and in the source (via ``src_padding``) are
        masked out of attention. Returns the decoder hidden states.
        """
        tgt_padding = tgt.eq(self.pad_id)
        tgt_emb = self.pos(self.embedding(tgt) * self.scale)
        tgt_len = tgt.size(1)
        # True above the diagonal: position i may not attend to positions > i.
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_padding,
            memory_key_padding_mask=src_padding,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        decoder_input_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """Compute logits, and the cross-entropy loss when ``labels`` is given.

        Args:
            input_ids: Source token ids.
            decoder_input_ids: Decoder inputs. When omitted they are derived
                from ``labels`` by shifting right.
            labels: Target token ids; positions equal to ``-100`` are ignored
                by the loss.
            return_dict: Return a ``Seq2SeqLMOutput`` instead of a tuple.
                Defaults to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If neither ``decoder_input_ids`` nor ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)
        if decoder_input_ids is None:
            raise ValueError("Provide decoder_input_ids or labels to EdenForTextEnhancement.forward.")
        memory, src_padding = self.encode(input_ids)
        hidden = self.decode(decoder_input_ids, memory, src_padding)
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # The loss is computed on a float32 copy of the logits.
            loss = F.cross_entropy(
                logits.float().reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return Seq2SeqLMOutput(loss=loss, logits=logits)

    def _shift_right(self, labels: torch.Tensor) -> torch.Tensor:
        """Build decoder inputs from labels: prepend BOS, drop the last token.

        Any ``-100`` (ignored-label) entries are replaced by the pad id so they
        are valid embedding indices.
        """
        shifted = labels.new_zeros(labels.shape)
        shifted[:, 1:] = labels[:, :-1].clone()
        shifted[:, 0] = self.bos_id
        shifted.masked_fill_(shifted == -100, self.pad_id)
        return shifted

    # ------------------------------------------------------------------ #
    # Generation (ported from the original trainer, no KV cache needed for
    # the short sequences this model handles)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _apply_repetition_penalty(logits, tokens, repetition_penalty):
        """Penalise, in place, the logits of tokens already in ``tokens``.

        Positive logits are divided by the penalty and negative logits are
        multiplied by it, so a penalty above 1 always makes a seen token less
        likely. Dividing unconditionally would raise the probability of any
        seen token whose logit is negative. Token ids outside the vocabulary
        range are ignored. A penalty of exactly 1.0 is a no-op.
        """
        if repetition_penalty == 1.0:
            return logits
        seen = [t for t in set(tokens) if 0 <= t < logits.numel()]
        if not seen:
            return logits
        index = torch.tensor(seen, dtype=torch.long, device=logits.device)
        picked = logits[index]
        logits[index] = torch.where(
            picked < 0,
            picked * repetition_penalty,
            picked / repetition_penalty,
        )
        return logits

    def _next_token_logits(self, tokens, memory, src_padding, repetition_penalty):
        """Return float32 next-token logits for one partial output sequence.

        Decodes the last ``config.max_len`` tokens against ``memory``, applies
        the repetition penalty over the whole of ``tokens``, and sets the UNK,
        PAD and BOS logits to ``-inf`` so they cannot be chosen.
        """
        window = tokens[-self.config.max_len:]
        tgt = torch.tensor([window], dtype=torch.long, device=memory.device)
        hidden = self.decode(tgt, memory, src_padding)
        logits = self.lm_head(hidden[:, -1, :]).float().squeeze(0)
        logits = self._apply_repetition_penalty(logits, tokens, repetition_penalty)
        for blocked in (self.unk_id, self.pad_id, self.bos_id):
            logits[blocked] = -float("inf")
        return logits

    @torch.no_grad()
    def _beam_generate(self, src, beam_size, max_new_tokens, length_penalty, repetition_penalty):
        """Beam-search decode ``src`` and return the best output token ids.

        Beams are ranked by cumulative log-probability divided by
        ``length ** length_penalty``, where length excludes the leading BOS.
        The returned list is cut at the first EOS and contains no special
        tokens. Puts the model in eval mode.
        """
        self.eval()
        memory, src_padding = self.encode(src)

        def rank(item):
            toks, sc, _ = item
            length = max(1, len(toks) - 1)
            return sc / (length ** length_penalty)

        # Each beam is (token ids including BOS, cumulative log-prob, finished?).
        beams = [([self.bos_id], 0.0, False)]
        for _ in range(max_new_tokens):
            if all(done for _, _, done in beams):
                break
            candidates = []
            for tokens, score, done in beams:
                if done:
                    # Finished beams compete unchanged with the new expansions.
                    candidates.append((tokens, score, True))
                    continue
                logits = self._next_token_logits(tokens, memory, src_padding, repetition_penalty)
                log_probs = F.log_softmax(logits, dim=-1)
                values, indices = torch.topk(log_probs, k=min(beam_size, log_probs.numel()))
                for value, index in zip(values.tolist(), indices.tolist()):
                    candidates.append((tokens + [index], score + value, index == self.eos_id))
            candidates.sort(key=rank, reverse=True)
            beams = candidates[:beam_size]

        best = max(beams, key=rank)
        out = best[0][1:]
        if self.eos_id in out:
            out = out[: out.index(self.eos_id)]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
        return [t for t in out if t not in skip]

    @torch.no_grad()
    def _sample_generate(self, src, strategy, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
        """Decode ``src`` greedily, or by sampling when ``strategy == "sample"``.

        Sampling divides the logits by ``max(0.05, temperature)``, applies
        top-k / top-p filtering, and falls back to argmax if the resulting
        distribution is not finite. Decoding stops at EOS or after
        ``max_new_tokens`` steps. Puts the model in eval mode.
        """
        self.eval()
        memory, src_padding = self.encode(src)
        tokens = [self.bos_id]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
        for _ in range(max_new_tokens):
            logits = self._next_token_logits(tokens, memory, src_padding, repetition_penalty)
            if strategy == "sample":
                logits = logits / max(0.05, temperature)
                logits = self._filter_top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                if not torch.isfinite(probs).all() or float(probs.sum().item()) <= 0:
                    next_id = int(torch.argmax(logits).item())
                else:
                    next_id = int(torch.multinomial(probs.detach().cpu(), 1).item())
            else:
                next_id = int(torch.argmax(logits).item())
            if next_id == self.eos_id:
                break
            if next_id not in skip:
                tokens.append(next_id)
        return tokens[1:]

    @staticmethod
    def _filter_top_k_top_p(logits, top_k, top_p):
        """Return a copy of ``logits`` with filtered-out entries set to ``-inf``.

        Top-k keeps entries at or above the k-th largest value. Top-p keeps
        the smallest prefix of the sorted distribution whose cumulative
        probability exceeds ``top_p``; the top entry is always kept.
        """
        filtered = logits.clone()
        if top_k > 0 and top_k < filtered.numel():
            threshold = torch.topk(filtered, top_k).values[-1]
            filtered[filtered < threshold] = -float("inf")
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(filtered, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative = torch.cumsum(probs, dim=-1)
            remove = cumulative > top_p
            # Shift right so the entry that crosses the threshold is kept.
            remove[1:] = remove[:-1].clone()
            remove[0] = False
            filtered[sorted_indices[remove]] = -float("inf")
        return filtered

    def _chunk_text(self, text, tokenizer):
        """Split normalised ``text`` into chunks of at most ``max_len - 2`` tokens.

        Text that already fits is returned as a single chunk. Otherwise whole
        sentences are packed greedily, and a sentence that is too long on its
        own is cut into fixed-size token windows and decoded back to text.
        """
        text = _normalise_text(text)
        ids = tokenizer.encode(text, add_special_tokens=False)
        max_src = self.config.max_len - 2
        if len(ids) <= max_src:
            return [text]
        chunks, current, current_ids = [], [], []
        for sent in _sentence_split(text) or [text]:
            sent_ids = tokenizer.encode(sent, add_special_tokens=False)
            if current and len(current_ids) + len(sent_ids) > max_src:
                # Flush the pending sentences before this one overflows them.
                chunks.append(" ".join(current))
                current, current_ids = [], []
            if len(sent_ids) > max_src:
                for i in range(0, len(sent_ids), max_src):
                    chunks.append(tokenizer.decode(sent_ids[i : i + max_src], skip_special_tokens=True))
            else:
                current.append(sent)
                current_ids.extend(sent_ids)
        if current:
            chunks.append(" ".join(current))
        return chunks

    @torch.no_grad()
    def enhance(
        self,
        tokenizer,
        text: str,
        strategy: str = "beam",
        beam_size: int | None = None,
        max_new_tokens: int | None = None,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        length_penalty: float | None = None,
        repetition_penalty: float | None = None,
    ) -> str:
        """Rewrite ``text`` into polished text.

        This mirrors the original trainer's enhancement pipeline: long inputs are
        split into sentence-aware chunks, each chunk is rewritten, and the
        results are joined back together.

        Args:
            tokenizer: Object providing ``encode(text, add_special_tokens=...)``
                and ``decode(ids, skip_special_tokens=...)``.
            text: Input text. Empty or whitespace-only input returns ``""``.
            strategy: ``"beam"``, ``"greedy"`` or ``"sample"``; any other value
                falls back to ``"beam"``.
            beam_size: Beam width; defaults to ``config.beam_size``.
            max_new_tokens: Per-chunk generation budget, clamped to the range
                ``[8, max(8, config.max_len - 1)]``.
            temperature: Sampling temperature (``"sample"`` only).
            top_k: Top-k cutoff (``"sample"`` only).
            top_p: Nucleus cutoff (``"sample"`` only).
            length_penalty: Beam length-normalisation exponent; defaults to
                ``config.length_penalty``.
            repetition_penalty: Penalty on already-generated tokens; defaults
                to ``config.repetition_penalty``.

        Returns:
            The rewritten text. A chunk whose generation decodes to an empty
            string is passed through unchanged.
        """
        # Nothing to rewrite: avoid running the model on a bare BOS/EOS pair
        # and returning whatever it happens to generate.
        if not _normalise_text(text):
            return ""
        strategy = strategy if strategy in {"beam", "greedy", "sample"} else "beam"
        beam = max(1, int(beam_size or self.config.beam_size))
        cap = max(8, self.config.max_len - 1)
        max_tokens = int(max_new_tokens) if max_new_tokens else min(256, cap)
        max_tokens = max(8, min(cap, max_tokens))
        len_penalty = self.config.length_penalty if length_penalty is None else float(length_penalty)
        rep_penalty = self.config.repetition_penalty if repetition_penalty is None else float(repetition_penalty)
        device = self.embedding.weight.device
        outputs = []
        for chunk in self._chunk_text(text, tokenizer):
            # Re-encoding a joined chunk can exceed the budget, so truncate.
            src_ids = tokenizer.encode(chunk, add_special_tokens=False)[: self.config.max_len - 2]
            src = torch.tensor([[self.bos_id] + src_ids + [self.eos_id]], dtype=torch.long, device=device)
            if strategy == "beam":
                out_ids = self._beam_generate(src, beam, max_tokens, len_penalty, rep_penalty)
            else:
                out_ids = self._sample_generate(
                    src, strategy, max_tokens, temperature, top_k, top_p, rep_penalty
                )
            decoded = _normalise_text(tokenizer.decode(out_ids, skip_special_tokens=True))
            outputs.append(decoded or chunk)
        return _normalise_text(" ".join(outputs))