import os
import sys
import ast
import torch
import itertools
import collections

sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader
DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"
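
# CategoricalEncoder maps hashable labels (e.g. speaker identities) to integer
# indices starting at starting_index, and back again; the mapping can be saved
# to and restored from a plain-text file.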
@register_checkpoint_hooks
class CategoricalEncoder:
    VALUE_SEPARATOR = " => "
    EXTRAS_SEPARATOR = "================\n"

    def __init__(self, starting_index=0, **special_labels):
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = starting_index
        self.handle_special_labels(special_labels)

    def handle_special_labels(self, special_labels):
        if "unk_label" in special_labels: self.add_unk(special_labels["unk_label"])

    def __len__(self):
        return len(self.lab2ind)

    @classmethod
    def from_saved(cls, path):
        obj = cls()
        obj.load(path)
        return obj

    def update_from_iterable(self, iterable, sequence_input=False):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        for label in label_iterator:
            self.ensure_label(label)

    def update_from_didataset(self, didataset, output_key, sequence_input=False):
        with didataset.output_keys_as([output_key]):
            self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)

    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        counts = collections.Counter(label_iterator)
        for label, count in counts.most_common(n_most_common):
            if count < min_count: break
            self.add_label(label)
        return counts
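
    # load_or_create is DDP-aware: only the main process builds the mapping and
    # writes it to disk, every process then waits at the barrier, and finally all
    # processes load the same file so their encoders stay identical.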
    def load_or_create(self, path, from_iterables=[], from_didatasets=[], sequence_input=False, output_key=None, special_labels={}):
        try:
            if if_main_process():
                if not self.load_if_possible(path):
                    for iterable in from_iterables:
                        self.update_from_iterable(iterable, sequence_input)
                    for didataset in from_didatasets:
                        if output_key is None: raise ValueError("output_key must be given when building from a didataset")
                        self.update_from_didataset(didataset, output_key, sequence_input)
                    self.handle_special_labels(special_labels)
                    self.save(path)
        finally:
            ddp_barrier()
            self.load(path)

    def add_label(self, label):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label}")
        index = self._next_index()
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        return index

    def ensure_label(self, label):
        if label in self.lab2ind: return self.lab2ind[label]
        else: return self.add_label(label)
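
    # insert_label/enforce_label place a label at a specific index; if that index
    # is already occupied, enforce_label moves the previously stored label to the
    # next free index.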
    def insert_label(self, label, index):
        if label in self.lab2ind: raise KeyError(f"Label already present: {label}")
        else: self.enforce_label(label, index)

    def enforce_label(self, label, index):
        index = int(index)
        if label in self.lab2ind:
            if index == self.lab2ind[label]: return
            else: del self.ind2lab[self.lab2ind[label]]
        if index in self.ind2lab:
            saved_label = self.ind2lab[index]
            moving_other = True
        else: moving_other = False
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        if moving_other:
            new_index = self._next_index()
            self.lab2ind[saved_label] = new_index
            self.ind2lab[new_index] = saved_label

    def add_unk(self, unk_label=DEFAULT_UNK):
        self.unk_label = unk_label
        return self.add_label(unk_label)

    def _next_index(self):
        index = self.starting_index
        while index in self.ind2lab:
            index += 1
        return index

    def is_continuous(self):
        indices = sorted(self.ind2lab.keys())
        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))

    def encode_label(self, label, allow_unk=True):
        self._assert_len()
        try:
            return self.lab2ind[label]
        except KeyError:
            if hasattr(self, "unk_label") and allow_unk: return self.lab2ind[self.unk_label]
            elif hasattr(self, "unk_label"): raise KeyError(f"Unknown label {label} and allow_unk=False")
            else: raise KeyError(f"Unknown label {label} and no unk_label has been added")

    def encode_label_torch(self, label, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk)])

    def encode_sequence(self, sequence, allow_unk=True):
        self._assert_len()
        return [self.encode_label(label, allow_unk) for label in sequence]

    def encode_sequence_torch(self, sequence, allow_unk=True):
        return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])

    def decode_torch(self, x):
        self._assert_len()
        decoded = []
        if x.ndim == 1:
            for element in x:
                decoded.append(self.ind2lab[int(element)])
        else:
            for subtensor in x:
                decoded.append(self.decode_torch(subtensor))
        return decoded
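
    # decode_ndim handles arbitrarily nested lists/tuples/tensors by recursing
    # until iteration fails with TypeError, i.e. a scalar leaf is reached.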
    def decode_ndim(self, x):
        self._assert_len()
        try:
            decoded = []
            for subtensor in x:
                decoded.append(self.decode_ndim(subtensor))
            return decoded
        except TypeError:
            return self.ind2lab[int(x)]
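
    # On-disk format: one "repr(label) => index" line per label, then the
    # EXTRAS_SEPARATOR line, then "repr(key) => repr(value)" lines for extras
    # such as starting_index and unk_label.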
    @mark_as_saver
    def save(self, path):
        self._save_literal(path, self.lab2ind, self._get_extras())

    def load(self, path):
        lab2ind, ind2lab, extras = self._load_literal(path)
        self.lab2ind = lab2ind
        self.ind2lab = ind2lab
        self._set_extras(extras)

    @mark_as_loader
    def load_if_possible(self, path, end_of_epoch=False):
        del end_of_epoch
        try:
            self.load(path)
        except FileNotFoundError:
            return False
        except (ValueError, SyntaxError):
            return False
        return True

    def expect_len(self, expected_len):
        self.expected_len = expected_len

    def ignore_len(self):
        self.expected_len = None

    def _assert_len(self):
        if hasattr(self, "expected_len"):
            if self.expected_len is None: return
            if len(self) != self.expected_len: raise RuntimeError(f"Encoder has {len(self)} labels but expect_len({self.expected_len}) was set")
        else:
            self.ignore_len()
            return

    def _get_extras(self):
        extras = {"starting_index": self.starting_index}
        if hasattr(self, "unk_label"): extras["unk_label"] = self.unk_label
        return extras

    def _set_extras(self, extras):
        if "unk_label" in extras: self.unk_label = extras["unk_label"]
        self.starting_index = extras["starting_index"]

    @staticmethod
    def _save_literal(path, lab2ind, extras):
        with open(path, "w", encoding="utf-8") as f:
            for label, ind in lab2ind.items():
                f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")
            f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
            for key, value in extras.items():
                f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")
            f.flush()

    @staticmethod
    def _load_literal(path):
        lab2ind, ind2lab, extras = {}, {}, {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line == CategoricalEncoder.EXTRAS_SEPARATOR: break
                literal, ind = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                label = ast.literal_eval(literal)
                lab2ind[label] = int(ind)
                ind2lab[int(ind)] = label
            for line in f:
                literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
                extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)
        return lab2ind, ind2lab, extras
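

# A minimal usage sketch (not part of the original module), assuming the class
# behaves as defined above; the labels and the temporary file path are purely
# illustrative.
if __name__ == "__main__":
    import tempfile

    encoder = CategoricalEncoder()
    encoder.update_from_iterable(["spk1", "spk2", "spk1", "spk3"])
    encoder.add_unk()      # adds "<unk>" so unseen labels can still be encoded
    encoder.expect_len(4)  # 3 speakers + <unk>

    ids = encoder.encode_sequence_torch(["spk1", "spk3", "never-seen"])  # unknown label maps to <unk>
    print(ids.tolist(), encoder.decode_torch(ids))

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "label_encoder.txt")
        encoder.save(path)
        restored = CategoricalEncoder.from_saved(path)
        assert restored.lab2ind == encoder.lab2ind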