| """ |
| aeneas_decode.py |
| ~~~~~~~~~~~~~~~~ |
| Confidence-ordered beam search, following Assael et al. (2022) «Ithaca/Aeneas». |
| |
| Algorithm per step: |
| 1. Forward pass — получаем логиты для ВСЕХ оставшихся [MASK] позиций сразу. |
| 2. Находим позицию с максимальной уверенностью модели |
| (argmax по max-probability среди всех масок в этом биме). |
| 3. Расширяем только эту позицию: берём top-k токенов. |
| 4. Обрезаем до top-k бимов по суммарному log-probability. |
| 5. Повторяем пока не останется ни одной маски. |
| """ |
|
|
| from pathlib import Path |
| import math |
| from dataclasses import dataclass, field |
| from typing import List, Optional |
|
|
| import torch |
| from transformers import PreTrainedModel, PreTrainedTokenizerBase |
|
|
|
|
@dataclass
class Beam:
    """A single hypothesis tracked by the beam search.

    Attributes:
        input_ids: Token ids of the sequence for this hypothesis.
        log_prob: Accumulated log-probability of the tokens chosen so far.
        filled: ``(position, token_id)`` pairs in the order they were filled.
    """

    input_ids: torch.Tensor
    log_prob: float = 0.0
    # default_factory gives every beam its own list instead of a shared one.
    filled: List[tuple] = field(default_factory=list)
|
|
|
|
def aeneas_beam_search(
    input_ids: torch.Tensor,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    *,
    beam_width: int = 5,
    temperature: float = 1.0,
    banned_token_ids: Optional[List[int]] = None,
) -> List[Beam]:
    """Fill every mask token in ``input_ids`` with a confidence-ordered beam search.

    On each step every live beam is run through the model. Among the beam's
    remaining masked positions, the one whose best token has the highest
    probability is selected, and only that position is expanded with its top
    ``beam_width`` tokens. The pooled candidates are then pruned to the
    ``beam_width`` best by accumulated log-probability. The loop runs once
    per mask token found in the input.

    Args:
        input_ids: 1-D tensor of token ids (the indexing below assumes there
            is no batch dimension).
        model: Called as ``model(ids.unsqueeze(0))``; the result must expose
            ``.logits``.
        tokenizer: Only ``mask_token_id`` is read from it.
        beam_width: Number of tokens expanded per beam and number of beams
            kept after pruning.
        temperature: Logits are divided by ``max(temperature, 1e-6)`` before
            normalisation.
        banned_token_ids: Token ids that must never be proposed. The mask
            token id is always added to this set, otherwise a mask could be
            "filled" with another mask and stay unresolved. Ids outside
            ``[0, vocab_size)`` are ignored.

    Returns:
        At most ``beam_width`` beams sorted by descending ``log_prob``. If the
        input contains no mask token, a single beam holding a copy of the
        input is returned.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1 or the tokenizer has
            no mask token.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise ValueError("tokenizer has no mask token (mask_token_id is None)")

    banned = set(banned_token_ids or [])
    banned.add(mask_id)

    # Loop-invariant: guard against division by zero for temperature <= 0.
    temp = max(temperature, 1e-6)

    beams: List[Beam] = [Beam(input_ids=input_ids.clone())]
    n_masks = int((input_ids == mask_id).sum().item())

    with torch.no_grad():
        for _ in range(n_masks):
            candidates: List[Beam] = []

            for beam in beams:
                mask_positions = (beam.input_ids == mask_id).nonzero(
                    as_tuple=True)[0].tolist()

                if not mask_positions:
                    # Nothing left to fill; carry the beam over unchanged.
                    candidates.append(beam)
                    continue

                logits = model(beam.input_ids.unsqueeze(0)).logits[0]

                # One row per remaining mask; the division yields a fresh
                # tensor, so the in-place masking never touches ``logits``.
                scaled = logits[mask_positions].float() / temp
                vocab_size = scaled.shape[-1]
                blocked = sorted(
                    tid for tid in banned if 0 <= tid < vocab_size
                )
                if blocked:
                    scaled[:, blocked] = float("-inf")

                # log_softmax is computed once and reused for both the
                # confidence ranking and the candidate scores.
                log_probs = scaled.log_softmax(dim=-1)

                # Most confident position = the row with the largest top
                # log-probability (argmax returns the first one on ties).
                row = int(log_probs.max(dim=-1).values.argmax().item())
                best_pos = mask_positions[row]

                # topk raises if k exceeds the size of the dimension.
                k = min(beam_width, vocab_size)
                top_log_probs, top_ids = log_probs[row].topk(k)

                for token_log_prob, token_id in zip(
                    top_log_probs.tolist(), top_ids.tolist()
                ):
                    if not math.isfinite(token_log_prob):
                        # Banned token (-inf) or a degenerate distribution.
                        continue
                    new_ids = beam.input_ids.clone()
                    new_ids[best_pos] = token_id
                    candidates.append(Beam(
                        input_ids=new_ids,
                        log_prob=beam.log_prob + token_log_prob,
                        filled=beam.filled + [(best_pos, token_id)],
                    ))

            # Keep only the best ``beam_width`` hypotheses for the next step.
            candidates.sort(key=lambda b: b.log_prob, reverse=True)
            beams = candidates[:beam_width]

    return beams
|
|
|
|
| |
|
|
def decode_beams(
    beams: List[Beam],
    original_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
) -> List[dict]:
    """Convert beams into a human-readable list of dictionaries.

    Args:
        beams: Beams to render, in the order they should be reported.
        original_ids: Accepted for interface compatibility; not used here.
        tokenizer: Used to decode the full sequence and each filled token.

    Returns:
        One dict per beam, in input order, with the keys:
            "text": the decoded sequence with special tokens skipped,
            "filled_tokens": ``(position, token_str)`` pairs in fill order,
            "score": softmax of the beams' log-probs, rounded to 4 places,
            "log_prob": the beam's total log-prob, rounded to 4 places.
    """
    # Normalise across the given beams so that their scores sum to one.
    beam_log_probs = torch.tensor(
        [item.log_prob for item in beams], dtype=torch.float
    )
    normalized = beam_log_probs.softmax(dim=0).tolist()

    def _token_text(token_id):
        # Decode a single id and trim surrounding whitespace.
        piece = tokenizer.decode(
            [token_id],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return piece.strip()

    decoded = []
    for item, share in zip(beams, normalized):
        entry = {
            "text": tokenizer.decode(item.input_ids, skip_special_tokens=True),
            "filled_tokens": [
                (position, _token_text(token_id))
                for position, token_id in item.filled
            ],
            "score": round(share, 4),
            "log_prob": round(item.log_prob, 4),
        }
        decoded.append(entry)

    return decoded
|
|
|
|
| |
|
|
def restore(
    text: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    *,
    beam_width: int = 5,
    temperature: float = 1.0,
    gap_token: str = "[GAP]",
    max_length: int = 512,
) -> List[dict]:
    """High-level wrapper: take a string with mask tokens, return decoded beams.

    Args:
        text: Text containing one or more mask tokens.
        model: Masked language model; inputs are moved to the device of its
            first parameter.
        tokenizer: Tokenizer used to encode ``text`` and decode the results.
        beam_width: Number of beams.
        temperature: Values below 1 sharpen the distribution, values above 1
            soften it.
        gap_token: Gap token; excluded from predictions when it is present in
            the tokenizer vocabulary.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        The list of dictionaries produced by ``decode_beams``.
    """
    target_device = next(model.parameters()).device

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    # The tokenizer returns a batch of one; take the single sequence.
    token_ids = encoded["input_ids"][0].to(target_device)

    # Ban the gap token only if the tokenizer actually knows it.
    if gap_token in tokenizer.get_vocab():
        excluded = [tokenizer.convert_tokens_to_ids(gap_token)]
    else:
        excluded = []

    found = aeneas_beam_search(
        token_ids,
        model,
        tokenizer,
        beam_width=beam_width,
        temperature=temperature,
        banned_token_ids=excluded,
    )
    return decode_beams(found, token_ids, tokenizer)
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Command-line demo: load a masked LM, restore the text, print the beams.
    import argparse
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # The default model directory is resolved relative to this script.
    default_model_dir = Path(__file__).resolve().parent / "outputs/final_model"

    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default=str(default_model_dir))
    cli.add_argument("--text", default="поклоне ѿ [MASK] к ѥва про [MASK] ѡкупи")
    cli.add_argument("--top_k", type=int, default=5)
    cli.add_argument("--temp", type=float, default=1.0)
    args = cli.parse_args()

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForMaskedLM.from_pretrained(args.model).to(device)
    model.eval()

    print(f"\nВход: {args.text}\n")
    results = restore(
        args.text,
        model,
        tokenizer,
        beam_width=args.top_k,
        temperature=args.temp,
    )

    for rank, item in enumerate(results, start=1):
        print(f" [{rank}] score={item['score']:.3f} log_prob={item['log_prob']:.3f}")
        print(f" {item['text']}")
        print(f" заполнено: {item['filled_tokens']}")