Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.rnn import pad_sequence | |
| from typing import Union, List, Tuple | |
| from src.models.model import Model | |
| ################################################################################ | |
| # Extends Model class for speech recognition, with optional decoding | |
| ################################################################################ | |
class Decoder(object):
    """
    Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e). Base
    class for decoder objects, which convert emitted frame-by-frame token
    probabilities into a string transcription.
    """
    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: Union[int, None] = None,
                 blank_idx: int = 0):
        """
        Parameters
        ----------
        labels (list): character corresponding to each token index
        sep_idx (int, optional): index corresponding to space / separating
            character. If None, it is inferred from a ' ' or '|' entry in
            `labels`, falling back to the out-of-bounds sentinel
            len(labels) when no separator character is present.
        blank_idx (int): index corresponding to blank '_' character
        """
        self.labels = labels
        self.blank_idx = blank_idx
        # map token index -> character for fast lookup during decoding
        self.int_to_char = dict(enumerate(labels))
        if sep_idx is None:
            # use out-of-bounds index as "no separator" sentinel when the
            # label set has no separator character
            sep_idx = len(labels)
            if ' ' in labels:
                sep_idx = labels.index(' ')
            elif '|' in labels:
                sep_idx = labels.index('|')
        self.sep_idx = sep_idx

    def get_labels(self):
        """Return the list of valid tokens."""
        return self.labels

    def get_sep_idx(self):
        """Return the index of the separator token."""
        return self.sep_idx

    def get_blank_idx(self):
        """Return the index of the CTC blank token."""
        return self.blank_idx

    def __call__(self, emission: torch.Tensor, sizes=None):
        return self.decode(emission, sizes)

    def decode(self, emission: torch.Tensor, sizes=None):
        """
        Decode emitted token probabilities to obtain a string transcription.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens)
        sizes (Tensor): length in frames of each emission in batch

        Raises
        ------
        NotImplementedError: subclasses must implement decoding.
        """
        raise NotImplementedError
class GreedyCTCDecoder(Decoder):
    """
    A simple decoder module to map token probability sequences to transcripts.
    Decodes 'greedily' by selecting maximum-probability token at each time step.
    Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e).
    """
    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: Union[int, None] = None,
                 blank_idx: int = 0):
        super().__init__(labels, sep_idx, blank_idx)

    def convert_to_strings(self,
                           sequences,
                           sizes=None,
                           remove_repetitions=False,
                           return_offsets=False):
        """
        Given a list of sequences holding token numbers, return the
        corresponding strings. Optionally, collapse repeated token
        subsequences and return the frame offset of each emitted character.

        Parameters
        ----------
        sequences (Tensor): shape (n_batch, n_frames); holds argmax token
            index for each frame
        sizes (Tensor, optional): valid length in frames of each sequence;
            defaults to the full sequence length
        remove_repetitions (bool): collapse consecutive repeated tokens
        return_offsets (bool): also return per-character frame offsets

        Returns
        -------
        strings (list[str]); if `return_offsets`, also a parallel list of
        int tensors holding the frame index of each emitted character.
        """
        strings = []
        offsets = [] if return_offsets else None
        for i, sequence in enumerate(sequences):
            seq_len = sizes[i] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(
                sequence, seq_len, remove_repetitions)
            strings.append(string)
            if return_offsets:
                offsets.append(string_offsets)
        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self,
                       sequence,
                       size,
                       remove_repetitions=False):
        """
        Convert one token-index sequence into a string, dropping blanks and
        (optionally) collapsing consecutive repeats, CTC-style.

        Parameters
        ----------
        sequence (Tensor): 1-D tensor of token indices
        size (int): number of valid frames in `sequence`
        remove_repetitions (bool): collapse consecutive repeated tokens

        Returns
        -------
        string (str): decoded transcription
        offsets (Tensor): int tensor of frame indices, one per character
        """
        string = ''
        offsets = []
        for i in range(size):
            idx = sequence[i].item()
            # skip CTC blank tokens
            if idx == self.blank_idx:
                continue
            # skip repeated tokens if specified; comparison is on raw
            # indices, which also avoids the IndexError the old
            # labels[sep_idx] lookup raised when sep_idx was the
            # out-of-bounds "no separator" sentinel
            if remove_repetitions and i != 0 \
                    and idx == sequence[i - 1].item():
                continue
            string += self.int_to_char[idx]
            offsets.append(i)
        return string, torch.tensor(offsets, dtype=torch.int)

    def decode(self, emission, sizes=None):
        """
        Returns the argmax decoding given the emitted token probabilities.
        According to connectionist temporal classification (CTC), removes
        repeated elements in the decoded token sequence, as well as blanks.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens); a 2-D
            input is treated as a single-item batch
        sizes (Tensor): length in frames of each emission in batch

        Returns
        -------
        strings (list[str]): string transcription for each item in batch
        offsets (list[Tensor]): frame index of each predicted character
        """
        if emission.ndim == 2:  # require shape (n_batch, n_frames, n_tokens)
            emission = emission.unsqueeze(0)
        # compute max-probability label at each sequence index
        max_probs = torch.argmax(emission, dim=-1)  # (n_batch, sequence_len)
        strings, offsets = self.convert_to_strings(max_probs,
                                                   sizes,
                                                   remove_repetitions=True,
                                                   return_offsets=True)
        return strings, offsets
class SpeechRecognitionModel(Model):
    """
    Extends Model for speech recognition: wraps an acoustic model emitting
    per-frame token probabilities, with optional decoding of emissions into
    string transcriptions.
    """
    def __init__(self,
                 model: nn.Module,
                 decoder: Decoder = None
                 ):
        """
        Parameters
        ----------
        model (nn.Module): wrapped acoustic model; must expose its labels,
            separator index, and blank index via getter methods or
            attributes (see `_resolve_accessor`)
        decoder (Decoder, optional): converts emissions to transcriptions;
            defaults to a GreedyCTCDecoder configured from the wrapped
            model's tokens
        """
        super().__init__()
        self.model = model
        self.model.eval()
        # resolve token-metadata accessors on the wrapped model, preferring
        # a getter method over a plain attribute (deduplicated from three
        # copies of the same getattr/callable/raise pattern)
        self._get_labels_fn = self._resolve_accessor('get_labels', 'labels')
        self._get_sep_fn = self._resolve_accessor('get_sep_idx', 'sep_idx')
        self._get_blank_fn = self._resolve_accessor('get_blank_idx',
                                                    'blank_idx')
        # initialize decoder
        if decoder is None:
            decoder = GreedyCTCDecoder(
                labels=self.get_labels(),
                blank_idx=self.get_blank_idx(),
                sep_idx=self.get_sep_idx()
            )
        self.decoder = decoder
        # translate characters to token indices
        self.char_to_idx = {l: i for i, l in enumerate(decoder.labels)}

    def _resolve_accessor(self, method_name: str, attr_name: str):
        """
        Return a zero-argument callable fetching a token property from the
        wrapped model, preferring `.method_name()` over `.attr_name`. The
        lambdas defer the lookup (late binding), matching the original
        behavior if `self.model` is ever reassigned.

        Raises
        ------
        ValueError: if the wrapped model provides neither accessor.
        """
        if callable(getattr(self.model, method_name, None)):
            return lambda: getattr(self.model, method_name)()
        if getattr(self.model, attr_name, None) is not None:
            return lambda: getattr(self.model, attr_name)
        raise ValueError(f'Wrapped model must have method `.{method_name}()`'
                         f' or attribute `.{attr_name}`')

    def get_labels(self):
        """Retrieve a list of valid tokens"""
        return self._get_labels_fn()

    def get_blank_idx(self):
        """Return index of blank token"""
        return self._get_blank_fn()

    def get_sep_idx(self):
        """Return index of separator token"""
        return self._get_sep_fn()

    def forward(self, x: torch.Tensor):
        """Run the wrapped acoustic model on input `x`."""
        return self.model.forward(x)

    def transcribe(self, x: torch.Tensor, return_alignment: bool = False):
        """
        Transcribe model input `x` to text via the decoder.

        Parameters
        ----------
        x (Tensor): model input
        return_alignment (bool): if True, return the decoder's full output
            (transcriptions plus per-character frame offsets)

        Returns
        -------
        Transcriptions, optionally paired with alignment offsets.
        """
        decoded = self.decoder(self.model(x))
        if return_alignment:
            return decoded
        return decoded[0]

    def load_weights(self, path: str):
        """
        Best-effort load of weights from a checkpoint file. A missing or
        empty path is silently ignored; parameters with unknown names or
        mismatched shapes are reported and skipped.
        """
        # check if file exists
        if not path or not os.path.isfile(path):
            return
        model_state = self.model.state_dict()
        # map_location='cpu' so checkpoints saved on GPU load on CPU-only
        # hosts; previously torch.load(path) could fail without CUDA
        loaded_state = torch.load(path, map_location='cpu')
        for name, param in loaded_state.items():
            origname = name
            if name not in model_state:
                print("{} is not in the model.".format(origname))
                continue
            # `param` is loaded_state[origname]; no need to re-index
            if model_state[name].size() != param.size():
                print(
                    "Wrong parameter length: {}, model: {}, loaded: {}".format(
                        origname,
                        model_state[name].size(),
                        param.size()
                    )
                )
                continue
            # in-place copy keeps the live parameter tensors registered in
            # the wrapped module
            model_state[name].copy_(param)

    def extract_features(
            self,
            x: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Extract deep features.

        :param x: input
        :return: a list of tensors holding intermediate activations /
            features; empty if the wrapped model lacks `extract_features`
        """
        # EAFP fallback when the wrapped model has no extract_features.
        # NOTE(review): this also swallows AttributeErrors raised *inside*
        # the model's implementation — confirm that is intended.
        try:
            return self.model.extract_features(x)
        except AttributeError:
            return []

    def _str_to_tensor(self, seq: str):
        """Convert a transcription string to a 1-D long tensor of token
        indices using `self.char_to_idx`."""
        token_indices = [self.char_to_idx[c] for c in seq]
        return torch.as_tensor(token_indices, dtype=torch.long)

    def match_predict(self,
                      y_pred: Union[List[str], torch.Tensor],
                      y_true: Union[List[str], torch.Tensor]):
        """
        Determine whether (batched) target pairs are equivalent.

        Parameters
        ----------
        y_pred: predicted transcriptions, as a list of strings or a padded
            token-index tensor of shape (n_batch, max_seq_len)
        y_true: ground-truth transcriptions, in the same formats

        Returns
        -------
        Tensor of n_batch booleans; item i is True when the first
        len(y_true[i]) predicted tokens equal the ground truth exactly
        (predicted tokens beyond the true length are ignored).
        """
        n_batch = len(y_pred)
        y_true_lengths = None
        # convert ground-truth transcriptions to tensor form
        if isinstance(y_true, list):
            y_true = [self._str_to_tensor(t) for t in y_true]
            y_true_lengths = [t.shape[-1] for t in y_true]
            y_true = pad_sequence(
                y_true,
                batch_first=True
            )  # (n_batch, max_seq_len)
        if y_true_lengths is None:
            # tensor input: treat every item as spanning the padded length
            y_true_lengths = [y_true.shape[-1]] * n_batch
        # convert predicted transcriptions to tensor form
        if isinstance(y_pred, list):
            y_pred = [self._str_to_tensor(t) for t in y_pred]
            y_pred = pad_sequence(
                y_pred,
                batch_first=True
            )  # (n_batch, max_seq_len)
        # right-pad predictions so slicing to each true length is in-bounds
        length_diff = max(0, y_true.shape[-1] - y_pred.shape[-1])
        if length_diff:
            y_pred = F.pad(y_pred, (0, length_diff))
        # masked comparison: only the first y_true_lengths[i] tokens count
        matches = []
        for i in range(n_batch):
            matches.append(
                torch.all(
                    y_pred[i, ..., :y_true_lengths[i]]
                    == y_true[i, ..., :y_true_lengths[i]]
                )
            )
        return torch.as_tensor(matches)
| """ | |
| # masked comparison | |
| use which one as dimension to select --- true or pred? | |
| pred lengths may be unnecessary! just select to true length | |
| """ | |