| |
| |
| |
| |
|
|
| import os |
| import subprocess |
| import threading |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def fasta_file_path(prefix_path):
    """Return the FASTA file name for *prefix_path* by appending ``.fasta``."""
    suffix = ".fasta"
    return prefix_path + suffix
|
|
|
|
class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format.

    Items are ``(description, sequence)`` string tuples. The description is
    the stripped header line (leading ``>`` included) and the sequence is the
    concatenation of the stripped lines that follow it, up to the next header
    or the end of the file.

    Random access is driven by an index of header byte offsets
    (``self.offsets``) and per-record sequence lengths (``self.sizes``).
    """

    def __init__(self, path: str, cache_indices=False):
        """
        Args:
            path: dataset prefix; the data is read from ``path + ".fasta"``.
            cache_indices: when true, load the index from
                ``<path>.fasta.idx.npy`` if that file exists, otherwise build
                the index and save it there.
        """
        self.fn = fasta_file_path(path)
        # One file handle per thread, opened lazily by _get_file().
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices and self.cache.exists():
            self.offsets, self.sizes = np.load(self.cache)
        else:
            self.offsets, self.sizes = self._build_index(path)
            if cache_indices:
                np.save(self.cache, np.stack([self.offsets, self.sizes]))

    def _get_file(self):
        """Return the calling thread's handle on the FASTA file, opening it on first use."""
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        """Return ``(description, sequence)`` for the record at position *idx*."""
        f = self._get_file()
        # NOTE(review): offsets are byte positions but the file is opened in
        # text mode; this assumes byte offsets are valid text-mode seek
        # targets (expected for plain ASCII FASTA) - confirm for other encodings.
        f.seek(int(self.offsets[idx]))
        desc = f.readline().strip()
        # Collect the pieces and join once instead of repeated string +=.
        pieces = []
        line = f.readline()
        while line and not line.startswith(">"):
            pieces.append(line.strip())
            line = f.readline()
        return desc, "".join(pieces)

    def __len__(self):
        """Return the number of indexed records."""
        return self.offsets.size

    def _build_index(self, path: str):
        """
        Build the record index for the FASTA file belonging to prefix *path*.

        Returns:
            ``(offsets, sizes)``: two ``int64`` arrays with one entry per
            header line - the byte offset of the header and the length of
            the record's sequence.
        """
        return self._index_fasta(fasta_file_path(path))

    @staticmethod
    def _index_fasta(fasta_path):
        """
        Scan *fasta_path* once and return ``(offsets, sizes)`` as ``int64`` arrays.

        The file is read in binary mode so offsets are true byte positions.
        Lines before the first header are ignored. Every header gets a size
        entry, including a trailing record with no sequence lines, so both
        arrays always have the same length. Sizes count the bytes of each
        stripped sequence line.
        """
        # Local import: compact typed buffers avoid one Python int object per
        # record while scanning very large files.
        from array import array

        offsets = array("q")
        sizes = array("q")
        position = 0
        with open(fasta_path, "rb") as handle:
            for raw in handle:
                if raw.startswith(b">"):
                    offsets.append(position)
                    sizes.append(0)
                elif sizes:
                    sizes[-1] += len(raw.strip())
                position += len(raw)
        return (
            np.array(offsets, dtype=np.int64),
            np.array(sizes, dtype=np.int64),
        )

    def __setstate__(self, state):
        """Restore pickled state and create a fresh thread-local holder."""
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        """Return the instance state without ``threadlocal`` (open file handles)."""
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "threadlocal"
        }

    def __del__(self):
        """Close the calling thread's file handle, if one was opened."""
        # threadlocal may be missing if __init__ raised before assigning it.
        local = getattr(self, "threadlocal", None)
        if local is not None and hasattr(local, "f"):
            local.f.close()
            del local.f

    @staticmethod
    def exists(path):
        """Return True if the FASTA file for prefix *path* exists."""
        return os.path.exists(fasta_file_path(path))
|
|
|
|
class EncodedFastaDataset(FastaDataset):
    """
    FastaDataset variant that yields encoded sequences rather than raw text.

    Each sequence is passed through ``dictionary.encode_line`` with one token
    per character, and the record description is discarded.
    """

    def __init__(self, path, dictionary):
        # The encoded variant always turns on index caching.
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        _, sequence = super().__getitem__(idx)
        # line_tokenizer=list splits the sequence into single characters.
        # presumably encode_line returns a tensor, given the .long() call -
        # verify against the dictionary implementation.
        encoded = self.dictionary.encode_line(sequence, line_tokenizer=list)
        return encoded.long()
|
|