Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import subprocess | |
| import threading | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| def fasta_file_path(prefix_path): | |
| return prefix_path + ".fasta" | |
| class FastaDataset(torch.utils.data.Dataset): | |
| """ | |
| For loading protein sequence datasets in the common FASTA data format | |
| """ | |
| def __init__(self, path: str, cache_indices=False): | |
| self.fn = fasta_file_path(path) | |
| self.threadlocal = threading.local() | |
| self.cache = Path(f"{path}.fasta.idx.npy") | |
| if cache_indices: | |
| if self.cache.exists(): | |
| self.offsets, self.sizes = np.load(self.cache) | |
| else: | |
| self.offsets, self.sizes = self._build_index(path) | |
| np.save(self.cache, np.stack([self.offsets, self.sizes])) | |
| else: | |
| self.offsets, self.sizes = self._build_index(path) | |
| def _get_file(self): | |
| if not hasattr(self.threadlocal, "f"): | |
| self.threadlocal.f = open(self.fn, "r") | |
| return self.threadlocal.f | |
| def __getitem__(self, idx): | |
| f = self._get_file() | |
| f.seek(self.offsets[idx]) | |
| desc = f.readline().strip() | |
| line = f.readline() | |
| seq = "" | |
| while line != "" and line[0] != ">": | |
| seq += line.strip() | |
| line = f.readline() | |
| return desc, seq | |
| def __len__(self): | |
| return self.offsets.size | |
| def _build_index(self, path: str): | |
| # Use grep and awk to get 100M/s on local SSD. | |
| # Should process your enormous 100G fasta in ~10 min single core... | |
| path = fasta_file_path(path) | |
| bytes_offsets = subprocess.check_output( | |
| f"cat {path} | tqdm --bytes --total $(wc -c < {path})" | |
| "| grep --byte-offset '^>' -o | cut -d: -f1", | |
| shell=True, | |
| ) | |
| fasta_lengths = subprocess.check_output( | |
| f"cat {path} | tqdm --bytes --total $(wc -c < {path})" | |
| "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", | |
| shell=True, | |
| ) | |
| bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") | |
| sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") | |
| return bytes_np, sizes_np | |
| def __setstate__(self, state): | |
| self.__dict__ = state | |
| self.threadlocal = threading.local() | |
| def __getstate__(self): | |
| d = {} | |
| for i, v in self.__dict__.items(): | |
| if i != "threadlocal": | |
| d[i] = v | |
| return d | |
| def __del__(self): | |
| if hasattr(self.threadlocal, "f"): | |
| self.threadlocal.f.close() | |
| del self.threadlocal.f | |
| def exists(path): | |
| return os.path.exists(fasta_file_path(path)) | |
| class EncodedFastaDataset(FastaDataset): | |
| """ | |
| The FastaDataset returns raw sequences - this allows us to return | |
| indices with a dictionary instead. | |
| """ | |
| def __init__(self, path, dictionary): | |
| super().__init__(path, cache_indices=True) | |
| self.dictionary = dictionary | |
| def __getitem__(self, idx): | |
| desc, seq = super().__getitem__(idx) | |
| return self.dictionary.encode_line(seq, line_tokenizer=list).long() | |