""" aeneas_decode.py ~~~~~~~~~~~~~~~~ Confidence-ordered beam search, following Assael et al. (2022) «Ithaca/Aeneas». Algorithm per step: 1. Forward pass — получаем логиты для ВСЕХ оставшихся [MASK] позиций сразу. 2. Находим позицию с максимальной уверенностью модели (argmax по max-probability среди всех масок в этом биме). 3. Расширяем только эту позицию: берём top-k токенов. 4. Обрезаем до top-k бимов по суммарному log-probability. 5. Повторяем пока не останется ни одной маски. """ from pathlib import Path import math from dataclasses import dataclass, field from typing import List, Optional import torch from transformers import PreTrainedModel, PreTrainedTokenizerBase @dataclass class Beam: input_ids: torch.Tensor # [seq_len] log_prob: float = 0.0 # Список (position, token_id) в порядке заполнения filled: List[tuple] = field(default_factory=list) def aeneas_beam_search( input_ids: torch.Tensor, # [seq_len], уже на нужном device model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, *, beam_width: int = 5, temperature: float = 1.0, banned_token_ids: Optional[List[int]] = None, # e.g. [GAP_ID] ) -> List[Beam]: """ Возвращает `beam_width` бимов, отсортированных по убыванию log-probability. """ device = input_ids.device mask_id = tokenizer.mask_token_id banned = set(banned_token_ids or []) # Инициализируем один начальный бим beams: List[Beam] = [Beam(input_ids=input_ids.clone())] # Считаем сколько масок нужно заполнить n_masks = (input_ids == mask_id).sum().item() with torch.no_grad(): for _ in range(n_masks): candidates: List[Beam] = [] for beam in beams: mask_positions = (beam.input_ids == mask_id).nonzero( as_tuple=True)[0].tolist() if not mask_positions: candidates.append(beam) continue # ── Forward pass ─────────────────────────────────────────── logits = model( beam.input_ids.unsqueeze(0) ).logits[0] # [seq_len, vocab] # ── Находим самую уверенную позицию ─────────────────────── # Для каждой маски берём вероятность наиболее вероятного токена. best_pos = None best_conf = -1.0 for pos in mask_positions: pos_logits = logits[pos] / max(temperature, 1e-6) if banned: pos_logits = pos_logits.clone() for tid in banned: if tid < pos_logits.shape[-1]: pos_logits[tid] = float("-inf") max_prob = pos_logits.softmax(dim=-1).max().item() if max_prob > best_conf: best_conf = max_prob best_pos = pos # ── Расширяем именно эту позицию ────────────────────────── pos_logits = logits[best_pos] / max(temperature, 1e-6) if banned: pos_logits = pos_logits.clone() for tid in banned: if tid < pos_logits.shape[-1]: pos_logits[tid] = float("-inf") probs = pos_logits.softmax(dim=-1) top_probs, top_ids = probs.topk(beam_width) for prob, token_id in zip(top_probs.tolist(), top_ids.tolist()): if prob <= 0: continue new_ids = beam.input_ids.clone() new_ids[best_pos] = token_id candidates.append(Beam( input_ids = new_ids, log_prob = beam.log_prob + math.log(prob + 1e-12), filled = beam.filled + [(best_pos, token_id)], )) # ── Pruning: оставляем top-beam_width бимов ─────────────────── beams = sorted(candidates, key=lambda b: b.log_prob, reverse=True) beams = beams[:beam_width] return beams # ── Вспомогательная функция: декодирование результатов ───────────────────────── def decode_beams( beams: List[Beam], original_ids: torch.Tensor, tokenizer: PreTrainedTokenizerBase, ) -> List[dict]: """ Превращает бимы в читаемый список словарей. Возвращает: [ { "text": полностью восстановленный текст, "filled_tokens": [(position, token_str), ...] в порядке заполнения, "score": нормализованная вероятность (0..1), "log_prob": суммарный log-prob, }, ... ] """ results = [] # Нормализуем вероятности через softmax по log_prob бимов log_probs = torch.tensor([b.log_prob for b in beams], dtype=torch.float) scores = log_probs.softmax(dim=0).tolist() for beam, score in zip(beams, scores): text = tokenizer.decode(beam.input_ids, skip_special_tokens=True) filled_tokens = [ (pos, tokenizer.decode([tid], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()) for pos, tid in beam.filled ] results.append({ "text": text, "filled_tokens": filled_tokens, "score": round(score, 4), "log_prob": round(beam.log_prob, 4), }) return results # ── Высокоуровневый интерфейс ─────────────────────────────────────────────────── def restore( text: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, *, beam_width: int = 5, temperature: float = 1.0, gap_token: str = "[GAP]", max_length: int = 512, ) -> List[dict]: """ Высокоуровневая обёртка: принимает строку с [MASK], возвращает список бимов. Args: text: текст с одним или несколькими [MASK] токенами. gap_token: токен пропуска — исключается из предсказаний. beam_width: число бимов. temperature: <1 делает распределение острее, >1 — мягче. """ device = next(model.parameters()).device enc = tokenizer( text, return_tensors="pt", truncation=True, max_length=max_length, ) input_ids = enc["input_ids"][0].to(device) # Исключаем [GAP] из предсказаний banned = [] if gap_token in tokenizer.get_vocab(): banned.append(tokenizer.convert_tokens_to_ids(gap_token)) beams = aeneas_beam_search( input_ids, model, tokenizer, beam_width=beam_width, temperature=temperature, banned_token_ids=banned, ) return decode_beams(beams, input_ids, tokenizer) # ── CLI / быстрая проверка ───────────────────────────────────────────────────── if __name__ == "__main__": import argparse from transformers import AutoModelForMaskedLM, AutoTokenizer parser = argparse.ArgumentParser() _HERE = Path(__file__).resolve().parent parser.add_argument("--model", default=str(_HERE / "outputs/final_model")) parser.add_argument("--text", default="поклоне ѿ [MASK] к ѥва про [MASK] ѡкупи") parser.add_argument("--top_k", type=int, default=5) parser.add_argument("--temp", type=float, default=1.0) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(args.model) model = AutoModelForMaskedLM.from_pretrained(args.model).to(device) model.eval() print(f"\nВход: {args.text}\n") results = restore(args.text, model, tokenizer, beam_width=args.top_k, temperature=args.temp) for i, r in enumerate(results, 1): print(f" [{i}] score={r['score']:.3f} log_prob={r['log_prob']:.3f}") print(f" {r['text']}") print(f" заполнено: {r['filled_tokens']}")