Spaces:
Runtime error
Runtime error
| import torch | |
| from abc import ABC, abstractmethod | |
| from typing import List, Optional, Tuple | |
| from torch import Tensor | |
| from torch.nn.utils.rnn import pad_sequence | |
class BaseTokenizer(ABC):
    """Base class that maps between label strings and integer token ids.

    The class owns the vocabulary tables (``_itos`` / ``_stoi``) and the
    greedy and beam-search decoding loops. Subclasses are expected to
    override ``encode`` and ``_filter``; the versions here only raise
    ``NotImplementedError``.
    """

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        """Build the id->token tuple and the inverse token->id dict.

        Args:
            charset: String whose individual characters become tokens.
            specials_first: Tokens placed before the charset entries.
            specials_last: Tokens placed after the charset entries.
        """
        # NOTE(review): ``charset + '[UNK]'`` is a *string* concatenation, so
        # ``tuple(...)`` yields the five single characters '[', 'U', 'N',
        # 'K', ']' instead of one '[UNK]' entry. It is left untouched here
        # because changing it would shift the ids of ``specials_last`` and
        # change ``len(self)``; confirm the intended vocabulary first.
        self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
        # When a token occurs more than once in ``_itos`` the dict keeps the
        # highest index for it.
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the number of entries in the id->token table."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of ``tokens`` to its id.

        Raises:
            KeyError: If a character is not present in the vocabulary.
        """
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        """Map ids back to tokens.

        Returns a single joined string when ``join`` is true, otherwise the
        list of individual tokens.
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels; must be provided by a subclass."""
        raise NotImplementedError

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of per-step token distributions.

        Dispatches to ``beam_search_decode`` when ``beam_width > 1`` and to
        ``greedy_decode`` otherwise. Note that the two paths report
        confidence differently: greedy yields one tensor of per-step
        probabilities per sample, beam search yields one float per sample.
        """
        if beam_width > 1:
            return self.beam_search_decode(token_dists, beam_width, raw)
        return self.greedy_decode(token_dists, raw)

    def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Pick the most probable token at every step.

        Args:
            token_dists: Iterable of per-sample distributions; the last
                dimension is reduced with ``max``.
            raw: When true, skip ``_filter`` and return the unjoined token
                list for every step.

        Returns:
            ``(batch_tokens, batch_probs)`` with one entry per sample.
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs

    def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[float]]:
        """Decode each sample by keeping the ``beam_width`` best prefixes.

        Args:
            token_dists: Iterable of per-sample distributions, each an
                iterable of 1-D per-step probability vectors.
            beam_width: Number of prefixes kept after every step; values
                larger than the vocabulary are clamped for ``topk``.
            raw: When true, skip ``_filter`` and return the unjoined token
                list together with the product over all steps.

        Returns:
            ``(batch_tokens, batch_scores)``. Each score is a Python float:
            the product of the per-step probabilities of the best beam. When
            ``raw`` is false the per-step probabilities are passed through
            ``_filter`` first, so only the steps it keeps contribute.

        Raises:
            ValueError: If ``beam_width`` is smaller than 1.
        """
        if beam_width < 1:
            raise ValueError(f'beam_width must be >= 1, got {beam_width}')

        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            # Each beam is (token ids, per-step probabilities, running product).
            beams = [([], [], 1.0)]
            for step_dist in dist:
                # The top-k of a step does not depend on the beam, so compute
                # it once per step; clamp k so topk never exceeds the vocab.
                k = min(beam_width, step_dist.size(-1))
                top_probs, top_ids = step_dist.topk(k)
                top_probs = top_probs.tolist()
                top_ids = top_ids.tolist()
                candidates = [
                    (ids + [tok_id], probs + [p], score * p)
                    for ids, probs, score in beams
                    for p, tok_id in zip(top_probs, top_ids)
                ]
                # Stable sort: ties keep their generation order.
                candidates.sort(key=lambda c: c[2], reverse=True)
                beams = candidates[:beam_width]

            best_ids, best_step_probs, best_score = beams[0]
            if not raw:
                # Filter the per-step probabilities (not the already-multiplied
                # scalar) so steps dropped by ``_filter`` no longer drag the
                # reported confidence down.
                probs_tensor = torch.tensor(best_step_probs, dtype=torch.float64)
                ids_tensor = torch.tensor(best_ids, dtype=torch.long)
                probs_tensor, best_ids = self._filter(probs_tensor, ids_tensor)
                best_score = probs_tensor.prod().item()
            tokens = self._ids2tok(best_ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(best_score)
        return batch_tokens, batch_probs
class Tokenizer(BaseTokenizer):
    """Tokenizer with begin-, end- and padding-of-sequence markers.

    The EOS marker is placed ahead of the charset entries; the BOS and PAD
    markers are placed after them.
    """

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Build the vocabulary and cache the ids of the three markers."""
        leading = (self.EOS,)
        trailing = (self.BOS, self.PAD)
        super().__init__(charset, leading, trailing)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Turn labels into a padded, batch-first tensor of token ids.

        Every label is wrapped as ``[BOS] + ids + [EOS]``; shorter rows are
        filled with the PAD id.
        """
        rows = []
        for label in labels:
            wrapped = [self.bos_id] + self._tok2ids(label) + [self.eos_id]
            rows.append(torch.as_tensor(wrapped, dtype=torch.long, device=device))
        return pad_sequence(rows, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut a decoded sequence at the first EOS.

        The returned ids stop before the EOS; the returned probabilities
        keep one extra entry so the EOS step itself is included.
        """
        id_list = ids.tolist()
        if self.eos_id in id_list:
            cut = id_list.index(self.eos_id)
        else:
            # No EOS predicted: keep the whole sequence.
            cut = len(id_list)
        return probs[:cut + 1], id_list[:cut]