import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from typing import Union, List, Tuple

from src.models.model import Model

################################################################################
# Extends Model class for speech recognition, with optional decoding
################################################################################


class Decoder(object):
    """
    Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e). Base
    class for decoder objects, which convert emitted frame-by-frame token
    probabilities into a string transcription.
    """

    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: int = None,
                 blank_idx: int = 0):
        """
        Parameters
        ----------
        labels (list):   character corresponding to each token index
        sep_idx (int):   index corresponding to space / separating character.
                         If omitted, the index of ' ' (or, failing that, '|')
                         in `labels` is used; if neither is present, the
                         out-of-bounds index `len(labels)` is stored
        blank_idx (int): index corresponding to blank '_' character
        """
        self.labels = labels
        self.blank_idx = blank_idx
        self.int_to_char = dict(enumerate(labels))

        if sep_idx is None:
            if ' ' in labels:
                sep_idx = labels.index(' ')
            elif '|' in labels:
                sep_idx = labels.index('|')
            else:
                # use out-of-bounds index for separating character
                sep_idx = len(labels)
        self.sep_idx = sep_idx

    def get_labels(self):
        """Return the token labels this decoder was built with."""
        return self.labels

    def get_sep_idx(self):
        """Return the index of the separator token."""
        return self.sep_idx

    def get_blank_idx(self):
        """Return the index of the blank token."""
        return self.blank_idx

    def __call__(self, emission: torch.Tensor, sizes=None):
        """Shorthand for `decode(emission, sizes)`."""
        return self.decode(emission, sizes)

    def decode(self, emission: torch.Tensor, sizes=None):
        """
        Decode emitted token probabilities to obtain a string transcription.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens)
        sizes (Tensor):    length in frames of each emission in batch
        """
        raise NotImplementedError


class GreedyCTCDecoder(Decoder):
    """
    A simple decoder module to map token probability sequences to transcripts.
    Decodes 'greedily' by selecting maximum-probability token at each time
    step. Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e).
    """

    def __init__(self,
                 labels: Union[List[str], Tuple[str]],
                 sep_idx: int = None,
                 blank_idx: int = 0):
        super().__init__(labels, sep_idx, blank_idx)

    def convert_to_strings(self,
                           sequences,
                           sizes=None,
                           remove_repetitions=False,
                           return_offsets=False):
        """
        Given a list of sequences holding token numbers, return the
        corresponding strings. Optionally, collapse repeated token
        subsequences and return the frame offsets of the kept characters.

        Parameters
        ----------
        sequences (Tensor):        shape (n_batch, n_frames); holds argmax
                                   token index for each frame
        sizes:                     optional per-item number of frames to
                                   decode; when None, each sequence is decoded
                                   over its full length
        remove_repetitions (bool): if True, drop a token that equals the token
                                   of the immediately preceding frame
        return_offsets (bool):     if True, also return per-item offsets

        Returns
        -------
        strings (list[str]):     decoded string for each item in batch
        offsets (list[Tensor]):  only if `return_offsets`; frame index of each
                                 kept character, per item
        """
        strings = []
        offsets = [] if return_offsets else None

        for i, sequence in enumerate(sequences):
            seq_len = sizes[i] if sizes is not None else len(sequence)
            string, string_offsets = self.process_string(
                sequence, seq_len, remove_repetitions)
            strings.append(string)
            if return_offsets:
                offsets.append(string_offsets)

        if return_offsets:
            return strings, offsets
        return strings

    def process_string(self, sequence, size, remove_repetitions=False):
        """
        Convert a single sequence of token indices into a string.

        Blank tokens are always dropped. When `remove_repetitions` is set, a
        token equal to the token of the previous frame is dropped as well.

        Parameters
        ----------
        sequence (Tensor):         token index for each frame
        size:                      number of leading frames to decode
        remove_repetitions (bool): collapse consecutive duplicate tokens

        Returns
        -------
        string (str):     decoded characters, concatenated
        offsets (Tensor): int tensor with the frame index of each character
        """
        blank_char = self.int_to_char[self.blank_idx]

        chars = []
        offsets = []
        prev_char = None  # character emitted at the previous frame, if any

        for i in range(int(size)):
            char = self.int_to_char[sequence[i].item()]
            is_repeat = remove_repetitions and char == prev_char
            prev_char = char

            if char == blank_char or is_repeat:
                continue

            # The separator needs no special handling: it is appended like any
            # other character. Looking it up through `self.sep_idx` would fail
            # when no separator exists, since `sep_idx` is then out of bounds.
            chars.append(char)
            offsets.append(i)

        return ''.join(chars), torch.tensor(offsets, dtype=torch.int)

    def decode(self, emission, sizes=None):
        """
        Returns the argmax decoding given the emitted token probabilities.
        According to connectionist temporal classification (CTC), removes
        repeated elements in the decoded token sequence, as well as blanks.

        Parameters
        ----------
        emission (Tensor): shape (n_batch, n_frames, n_tokens)
        sizes (Tensor):    length in frames of each emission in batch

        Returns
        -------
        transcription (list[str]): string transcription for each item in batch
        offsets (list[Tensor]):    frame index per character predicted, for
                                   each item in batch
        """
        if emission.ndim == 2:
            # require shape (n_batch, n_frames, n_tokens)
            emission = emission.unsqueeze(0)

        # compute max-probability label at each sequence index
        max_probs = torch.argmax(emission, dim=-1)  # (n_batch, sequence_len)

        strings, offsets = self.convert_to_strings(
            max_probs,
            sizes,
            remove_repetitions=True,
            return_offsets=True)

        return strings, offsets


class SpeechRecognitionModel(Model):
    """
    Wraps a speech recognition network together with a decoder that turns its
    emissions into string transcriptions.
    """

    def __init__(self,
                 model: nn.Module,
                 decoder: Decoder = None
                 ):
        """
        Parameters
        ----------
        model (nn.Module): wrapped network; must expose its labels, separator
                           index and blank index, each either through a getter
                           method or through an attribute
        decoder (Decoder): decoder applied to the network output; defaults to
                           a `GreedyCTCDecoder` built from the wrapped model's
                           labels and indices

        Raises
        ------
        ValueError: if the wrapped model exposes neither the getter method nor
                    the attribute for labels, separator index or blank index
        """
        super().__init__()

        self.model = model
        self.model.eval()

        # ensure that list of viable tokens can be retrieved from wrapped model
        self._get_labels_fn = self._resolve_getter('get_labels', 'labels')

        # ensure that blank and separator tokens can be retrieved from wrapped
        # model
        self._get_sep_fn = self._resolve_getter('get_sep_idx', 'sep_idx')
        self._get_blank_fn = self._resolve_getter('get_blank_idx', 'blank_idx')

        # initialize decoder
        if decoder is None:
            decoder = GreedyCTCDecoder(
                labels=self.get_labels(),
                blank_idx=self.get_blank_idx(),
                sep_idx=self.get_sep_idx()
            )
        self.decoder = decoder

        # translate characters to token indices
        self.char_to_idx = {
            label: idx for idx, label in enumerate(decoder.labels)
        }

    def _resolve_getter(self, method_name: str, attr_name: str):
        """
        Build a zero-argument accessor reading a value from the wrapped model,
        preferring the method `method_name` over the attribute `attr_name`.
        The lookup goes through `self.model` on every call.
        """
        if callable(getattr(self.model, method_name, None)):
            return lambda: getattr(self.model, method_name)()
        if getattr(self.model, attr_name, None) is not None:
            return lambda: getattr(self.model, attr_name)
        raise ValueError(f'Wrapped model must have method `.{method_name}()`'
                         f' or attribute `.{attr_name}`')

    def get_labels(self):
        """Retrieve a list of valid tokens"""
        return self._get_labels_fn()

    def get_blank_idx(self):
        """Return index of blank token"""
        return self._get_blank_fn()

    def get_sep_idx(self):
        """Return index of separator token"""
        return self._get_sep_fn()

    def forward(self, x: torch.Tensor):
        """Run the wrapped model on `x` and return its output."""
        # Call the module rather than `.forward` so that registered hooks run,
        # matching what `transcribe` does.
        return self.model(x)

    def transcribe(self, x: torch.Tensor, return_alignment: bool = False):
        """
        Decode the wrapped model's output for `x`.

        Returns the decoder's full output when `return_alignment` is True,
        otherwise only its first element.
        """
        decoded = self.decoder(self.model(x))
        return decoded if return_alignment else decoded[0]

    def load_weights(self, path: str):
        """
        Load weights from checkpoint file. Does nothing if `path` is empty or
        is not an existing file. Entries that are unknown to the wrapped model
        or whose size differs are reported and skipped.
        """
        # check if file exists
        if not path or not os.path.isfile(path):
            return

        model_state = self.model.state_dict()
        # NOTE(review): `torch.load` unpickles the file; only load checkpoints
        # from trusted sources.
        loaded_state = torch.load(path)

        for name, param in loaded_state.items():
            if name not in model_state:
                print("{} is not in the model.".format(name))
                continue

            if model_state[name].size() != param.size():
                print(
                    "Wrong parameter length: {}, model: {}, loaded: {}".format(
                        name, model_state[name].size(), param.size()
                    )
                )
                continue

            model_state[name].copy_(param)

    def extract_features(
        self,
        x: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Extract deep features.

        :param x: input
        :return: a list of tensors holding intermediate activations / features;
                 empty if the wrapped model has no `extract_features` method
        """
        extractor = getattr(self.model, 'extract_features', None)
        if not callable(extractor):
            return []
        # Called outside any `except AttributeError` so that a genuine
        # AttributeError raised inside the wrapped model is not mistaken for
        # "feature extraction unsupported".
        return extractor(x)

    def _str_to_tensor(self, seq: str):
        """Map each character of `seq` to its token index (long tensor)."""
        token_indices = [self.char_to_idx[c] for c in seq]
        return torch.as_tensor(token_indices, dtype=torch.long)

    def match_predict(self,
                      y_pred: Union[List[str], torch.Tensor],
                      y_true: Union[List[str], torch.Tensor]):
        """
        Determine whether (batched) target pairs are equivalent.

        Transcriptions given as lists of strings are converted to zero-padded
        token-index tensors. Each pair is compared over the length of the
        ground-truth sequence only, so a prediction that is longer than its
        target but begins with it counts as a match.

        :param y_pred: predicted transcriptions, as strings or token indices
                       (tensor form presumably (n_batch, seq_len); TODO confirm)
        :param y_true: ground-truth transcriptions, in either form
        :return: tensor holding one match flag per batch item
        """
        n_batch = len(y_pred)
        y_true_lengths = None

        # convert ground-truth transcriptions to tensor form
        if isinstance(y_true, list):
            y_true = [self._str_to_tensor(t) for t in y_true]
            y_true_lengths = [t.shape[-1] for t in y_true]
            y_true = pad_sequence(
                y_true,
                batch_first=True
            )  # (n_batch, max_seq_len)

        if y_true_lengths is None:
            y_true_lengths = [y_true.shape[-1]] * n_batch

        # convert predicted transcriptions to tensor form
        if isinstance(y_pred, list):
            y_pred = [self._str_to_tensor(t) for t in y_pred]
            y_pred = pad_sequence(
                y_pred,
                batch_first=True
            )  # (n_batch, max_seq_len)

        # right-pad predictions so they can be sliced to every target length
        length_diff = max(0, y_true.shape[-1] - y_pred.shape[-1])
        if length_diff:
            y_pred = F.pad(y_pred, (0, length_diff))

        # NOTE: masked comparison. Open question: which one to use as the
        # dimension to select --- true or pred? Pred lengths may be
        # unnecessary; for now just select to true length.
        matches = [
            torch.all(
                y_pred[i, ..., :y_true_lengths[i]] ==
                y_true[i, ..., :y_true_lengths[i]]
            )
            for i in range(n_batch)
        ]

        return torch.as_tensor(matches)