Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset | |
| """ | |
| import argparse | |
| import glob | |
| import os | |
| import os.path as osp | |
| import pprint | |
| import soundfile as sf | |
| import torch | |
| import fairseq | |
| from torch import nn | |
| from torch.utils.data import DataLoader | |
| try: | |
| import tqdm | |
| except: | |
| print("Install tqdm to use --log-format=tqdm") | |
| class FilesDataset: | |
| def __init__(self, files, labels): | |
| self.files = files | |
| if labels and osp.exists(labels): | |
| with open(labels, "r") as lbl_f: | |
| self.labels = [line.rstrip() for line in lbl_f] | |
| else: | |
| self.labels = labels | |
| def __len__(self): | |
| return len(self.files) | |
| def __getitem__(self, index): | |
| fname = self.files[index] | |
| wav, sr = sf.read(fname) | |
| assert sr == 16000 | |
| wav = torch.from_numpy(wav).float() | |
| lbls = None | |
| if self.labels: | |
| if isinstance(self.labels, str): | |
| lbl_file = osp.splitext(fname)[0] + "." + self.labels | |
| with open(lbl_file, "r") as lblf: | |
| lbls = lblf.readline() | |
| assert lbls is not None | |
| else: | |
| lbls = self.labels[index] | |
| return wav, lbls | |
| def collate(self, batch): | |
| return batch | |
| class ArgTypes: | |
| def existing_path(arg): | |
| arg = str(arg) | |
| assert osp.exists(arg), f"File {arg} does not exist" | |
| return arg | |
| def mkdir(arg): | |
| arg = str(arg) | |
| os.makedirs(arg, exist_ok=True) | |
| return arg | |
| class DatasetWriter: | |
| def __init__(self): | |
| self.args = self.load_config() | |
| pprint.pprint(self.args.__dict__) | |
| self.model = self.load_model() | |
| def __getattr__(self, attr): | |
| return getattr(self.args, attr) | |
| def read_manifest(self, fname): | |
| with open(fname, "r") as fp: | |
| lines = fp.read().split("\n") | |
| root = lines.pop(0).strip() | |
| fnames = [ | |
| osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0 | |
| ] | |
| return fnames | |
| def process_splits(self): | |
| if self.args.shard is not None or self.args.num_shards is not None: | |
| assert self.args.shard is not None and self.args.num_shards is not None | |
| for split in self.splits: | |
| print(split) | |
| if self.extension == "tsv": | |
| datadir = osp.join(self.data_dir, f"{split}.{self.extension}") | |
| print("Reading manifest file: ", datadir) | |
| files = self.read_manifest(datadir) | |
| else: | |
| datadir = osp.join(self.data_dir, split, f"**/*.{self.extension}") | |
| files = glob.glob(datadir, recursive=True) | |
| assert len(files) > 0 | |
| if self.args.shard is not None: | |
| files = files[self.args.shard :: self.args.num_shards] | |
| lbls = [] | |
| with open(self.data_file(split), "w") as srcf: | |
| for line, lbl in self.iterate(files): | |
| print(line, file=srcf) | |
| if self.args.labels: | |
| lbls.append(lbl + "\n") | |
| if self.args.labels: | |
| assert all(a is not None for a in lbls) | |
| with open(self.lbl_file(split), "w") as lblf: | |
| lblf.writelines(lbls) | |
| def iterate(self, files): | |
| data = self.load_data(files) | |
| for samples in tqdm.tqdm(data, total=len(files) // 32): | |
| for wav, lbl in samples: | |
| x = wav.unsqueeze(0).float().cuda() | |
| div = 1 | |
| while x.size(-1) // div > self.args.max_size: | |
| div += 1 | |
| xs = x.chunk(div, dim=-1) | |
| result = [] | |
| for x in xs: | |
| torch.cuda.empty_cache() | |
| x = self.model.feature_extractor(x) | |
| if self.quantize_location == "encoder": | |
| with torch.no_grad(): | |
| _, idx = self.model.vector_quantizer.forward_idx(x) | |
| idx = idx.squeeze(0).cpu() | |
| else: | |
| with torch.no_grad(): | |
| z = self.model.feature_aggregator(x) | |
| _, idx = self.model.vector_quantizer.forward_idx(z) | |
| idx = idx.squeeze(0).cpu() | |
| result.append(idx) | |
| idx = torch.cat(result, dim=0) | |
| yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl | |
| def lbl_file(self, name): | |
| shard_part = "" if self.args.shard is None else f".{self.args.shard}" | |
| return osp.join(self.output_dir, f"{name}.lbl{shard_part}") | |
| def data_file(self, name): | |
| shard_part = "" if self.args.shard is None else f".{self.args.shard}" | |
| return osp.join(self.output_dir, f"{name}.src{shard_part}") | |
| def var_file(self): | |
| return osp.join(self.output_dir, f"vars.pt") | |
| def load_config(self): | |
| parser = argparse.ArgumentParser("Vector Quantized wav2vec features") | |
| # Model Arguments | |
| parser.add_argument("--checkpoint", type=ArgTypes.existing_path, required=True) | |
| parser.add_argument("--data-parallel", action="store_true") | |
| # Output Arguments | |
| parser.add_argument("--output-dir", type=ArgTypes.mkdir, required=True) | |
| # Data Arguments | |
| parser.add_argument("--data-dir", type=ArgTypes.existing_path, required=True) | |
| parser.add_argument("--splits", type=str, nargs="+", required=True) | |
| parser.add_argument("--extension", type=str, required=True) | |
| parser.add_argument("--labels", type=str, required=False) | |
| parser.add_argument("--shard", type=int, default=None) | |
| parser.add_argument("--num-shards", type=int, default=None) | |
| parser.add_argument("--max-size", type=int, default=1300000) | |
| # Logger Arguments | |
| parser.add_argument( | |
| "--log-format", type=str, choices=["none", "simple", "tqdm"] | |
| ) | |
| return parser.parse_args() | |
| def load_data(self, fnames): | |
| dataset = FilesDataset(fnames, self.args.labels) | |
| loader = DataLoader( | |
| dataset, batch_size=32, collate_fn=dataset.collate, num_workers=8 | |
| ) | |
| return loader | |
| def load_model(self): | |
| model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint]) | |
| model = model[0] | |
| self.quantize_location = getattr(cfg.model, "vq", "encoder") | |
| model.eval().float() | |
| model.cuda() | |
| if self.data_parallel: | |
| model = nn.DataParallel(model) | |
| return model | |
| def __call__(self): | |
| self.process_splits() | |
| if hasattr(self.model.feature_extractor, "vars") and ( | |
| self.args.shard is None or self.args.shard == 0 | |
| ): | |
| vars = ( | |
| self.model.feature_extractor.vars.view( | |
| self.model.feature_extractor.banks, | |
| self.model.feature_extractor.num_vars, | |
| -1, | |
| ) | |
| .cpu() | |
| .detach() | |
| ) | |
| print("writing learned latent variable embeddings: ", vars.shape) | |
| torch.save(vars, self.var_file()) | |
| if __name__ == "__main__": | |
| write_data = DatasetWriter() | |
| write_data() | |
| print("Done.") | |