RoFormer-slav / aeneas_decode.py
MaximEremeev's picture
Add RoFormer-slav
dbee6a0 verified
"""
aeneas_decode.py
~~~~~~~~~~~~~~~~
Confidence-ordered beam search, following Assael et al. (2022) «Ithaca/Aeneas».
Algorithm per step:
1. Forward pass — получаем логиты для ВСЕХ оставшихся [MASK] позиций сразу.
2. Находим позицию с максимальной уверенностью модели
(argmax по max-probability среди всех масок в этом биме).
3. Расширяем только эту позицию: берём top-k токенов.
4. Обрезаем до top-k бимов по суммарному log-probability.
5. Повторяем пока не останется ни одной маски.
"""
from pathlib import Path
import math
from dataclasses import dataclass, field
from typing import List, Optional
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@dataclass
class Beam:
input_ids: torch.Tensor # [seq_len]
log_prob: float = 0.0
# Список (position, token_id) в порядке заполнения
filled: List[tuple] = field(default_factory=list)
def aeneas_beam_search(
input_ids: torch.Tensor, # [seq_len], уже на нужном device
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
*,
beam_width: int = 5,
temperature: float = 1.0,
banned_token_ids: Optional[List[int]] = None, # e.g. [GAP_ID]
) -> List[Beam]:
"""
Возвращает `beam_width` бимов, отсортированных по убыванию log-probability.
"""
device = input_ids.device
mask_id = tokenizer.mask_token_id
banned = set(banned_token_ids or [])
# Инициализируем один начальный бим
beams: List[Beam] = [Beam(input_ids=input_ids.clone())]
# Считаем сколько масок нужно заполнить
n_masks = (input_ids == mask_id).sum().item()
with torch.no_grad():
for _ in range(n_masks):
candidates: List[Beam] = []
for beam in beams:
mask_positions = (beam.input_ids == mask_id).nonzero(
as_tuple=True)[0].tolist()
if not mask_positions:
candidates.append(beam)
continue
# ── Forward pass ───────────────────────────────────────────
logits = model(
beam.input_ids.unsqueeze(0)
).logits[0] # [seq_len, vocab]
# ── Находим самую уверенную позицию ───────────────────────
# Для каждой маски берём вероятность наиболее вероятного токена.
best_pos = None
best_conf = -1.0
for pos in mask_positions:
pos_logits = logits[pos] / max(temperature, 1e-6)
if banned:
pos_logits = pos_logits.clone()
for tid in banned:
if tid < pos_logits.shape[-1]:
pos_logits[tid] = float("-inf")
max_prob = pos_logits.softmax(dim=-1).max().item()
if max_prob > best_conf:
best_conf = max_prob
best_pos = pos
# ── Расширяем именно эту позицию ──────────────────────────
pos_logits = logits[best_pos] / max(temperature, 1e-6)
if banned:
pos_logits = pos_logits.clone()
for tid in banned:
if tid < pos_logits.shape[-1]:
pos_logits[tid] = float("-inf")
probs = pos_logits.softmax(dim=-1)
top_probs, top_ids = probs.topk(beam_width)
for prob, token_id in zip(top_probs.tolist(), top_ids.tolist()):
if prob <= 0:
continue
new_ids = beam.input_ids.clone()
new_ids[best_pos] = token_id
candidates.append(Beam(
input_ids = new_ids,
log_prob = beam.log_prob + math.log(prob + 1e-12),
filled = beam.filled + [(best_pos, token_id)],
))
# ── Pruning: оставляем top-beam_width бимов ───────────────────
beams = sorted(candidates, key=lambda b: b.log_prob, reverse=True)
beams = beams[:beam_width]
return beams
# ── Вспомогательная функция: декодирование результатов ─────────────────────────
def decode_beams(
beams: List[Beam],
original_ids: torch.Tensor,
tokenizer: PreTrainedTokenizerBase,
) -> List[dict]:
"""
Превращает бимы в читаемый список словарей.
Возвращает:
[
{
"text": полностью восстановленный текст,
"filled_tokens": [(position, token_str), ...] в порядке заполнения,
"score": нормализованная вероятность (0..1),
"log_prob": суммарный log-prob,
},
...
]
"""
results = []
# Нормализуем вероятности через softmax по log_prob бимов
log_probs = torch.tensor([b.log_prob for b in beams], dtype=torch.float)
scores = log_probs.softmax(dim=0).tolist()
for beam, score in zip(beams, scores):
text = tokenizer.decode(beam.input_ids, skip_special_tokens=True)
filled_tokens = [
(pos, tokenizer.decode([tid], skip_special_tokens=True,
clean_up_tokenization_spaces=False).strip())
for pos, tid in beam.filled
]
results.append({
"text": text,
"filled_tokens": filled_tokens,
"score": round(score, 4),
"log_prob": round(beam.log_prob, 4),
})
return results
# ── Высокоуровневый интерфейс ───────────────────────────────────────────────────
def restore(
text: str,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
*,
beam_width: int = 5,
temperature: float = 1.0,
gap_token: str = "[GAP]",
max_length: int = 512,
) -> List[dict]:
"""
Высокоуровневая обёртка: принимает строку с [MASK], возвращает список бимов.
Args:
text: текст с одним или несколькими [MASK] токенами.
gap_token: токен пропуска — исключается из предсказаний.
beam_width: число бимов.
temperature: <1 делает распределение острее, >1 — мягче.
"""
device = next(model.parameters()).device
enc = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=max_length,
)
input_ids = enc["input_ids"][0].to(device)
# Исключаем [GAP] из предсказаний
banned = []
if gap_token in tokenizer.get_vocab():
banned.append(tokenizer.convert_tokens_to_ids(gap_token))
beams = aeneas_beam_search(
input_ids, model, tokenizer,
beam_width=beam_width,
temperature=temperature,
banned_token_ids=banned,
)
return decode_beams(beams, input_ids, tokenizer)
# ── CLI / быстрая проверка ─────────────────────────────────────────────────────
if __name__ == "__main__":
import argparse
from transformers import AutoModelForMaskedLM, AutoTokenizer
parser = argparse.ArgumentParser()
_HERE = Path(__file__).resolve().parent
parser.add_argument("--model", default=str(_HERE / "outputs/final_model"))
parser.add_argument("--text", default="поклоне ѿ [MASK] к ѥва про [MASK] ѡкупи")
parser.add_argument("--top_k", type=int, default=5)
parser.add_argument("--temp", type=float, default=1.0)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForMaskedLM.from_pretrained(args.model).to(device)
model.eval()
print(f"\nВход: {args.text}\n")
results = restore(args.text, model, tokenizer,
beam_width=args.top_k, temperature=args.temp)
for i, r in enumerate(results, 1):
print(f" [{i}] score={r['score']:.3f} log_prob={r['log_prob']:.3f}")
print(f" {r['text']}")
print(f" заполнено: {r['filled_tokens']}")