ESMFold2-Fast / esmfold2_msa_filter_sequences.py
lhallee's picture
Upload folder using huggingface_hub
b44701d verified
Raw
History Blame Contribute Delete
2.9 kB
import os
import tempfile
from pathlib import Path
import numpy as np
from scipy.spatial.distance import cdist
from .esmfold2_system import run_subprocess_with_errorcheck
def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
    """Greedily select sequences that either maximize or minimize hamming distance.

    Algorithm proposed in the MSA Transformer paper. Starting from the query
    sequence (row 0), iteratively add the sequence with the maximum (minimum)
    average Hamming distance to the set of sequences selected so far.

    Args:
        array (np.ndarray): Character array representing the sequences in the MSA,
            one row per sequence. It is reinterpreted as ``np.uint8`` before
            distances are computed.
        num_seqs (int): Number of sequences to select.
        mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
            you're doing it to prove a point for a paper.

    Returns:
        list[int]: Sorted list of row indices to select from the array. If the
        array has at most ``num_seqs`` rows, every row index is returned;
        otherwise row 0 is always part of the selection.
    """
    assert mode in ("max", "min")

    depth = array.shape[0]
    if depth <= num_seqs:
        return list(range(depth))

    array = array.view(np.uint8)
    optfunc = np.argmax if mode == "max" else np.argmin

    indices = [0]
    is_selected = np.zeros(depth, dtype=bool)
    # Running sum, per row, of the Hamming distance to every selected row.
    # Keeping only the sum avoids storing (and repeatedly copying) one full
    # distance row per selected sequence.
    distance_sum = np.zeros(depth)

    for _ in range(num_seqs - 1):
        last = indices[-1]
        is_selected[last] = True
        # Only the most recently added row contributes new distances.
        distance_sum += cdist(array[last : last + 1], array, "hamming")[0]

        # Candidates stay in ascending index order, so ties resolve to the
        # lowest row index (argmax/argmin return the first extremum).
        candidates = np.flatnonzero(~is_selected)
        mean_distance = distance_sum[candidates] / len(indices)

        # Convert to a builtin int so the result really is a list[int].
        indices.append(int(candidates[optfunc(mean_distance)]))

    return sorted(indices)
def hhfilter(
    sequences: list[str],
    seqid: int = 90,
    diff: int = 0,
    cov: int = 0,
    qid: int = 0,
    qsc: float = -20.0,
    binary: str = "hhfilter",
) -> list[int]:
    """Filter an MSA with the external ``hhfilter`` program and return kept indices.

    The sequences are written to a temporary FASTA file whose record names are
    their positions in ``sequences``; ``hhfilter`` is run on that file and the
    record names of its output are parsed back into integers.

    Args:
        sequences: Aligned sequences, one string per MSA row.
        seqid: Value passed to ``-id`` (HH-suite: maximum pairwise sequence
            identity, in percent).
        diff: Value passed to ``-diff`` (HH-suite: diversity-based filtering;
            0 disables it).
        cov: Value passed to ``-cov`` (HH-suite: minimum coverage with the
            query, in percent).
        qid: Value passed to ``-qid`` (HH-suite: minimum sequence identity
            with the query, in percent).
        qsc: Value passed to ``-qsc`` (HH-suite: minimum score per column
            with the query).
        binary: Name or path of the ``hhfilter`` executable. It is passed as a
            single argv entry, so paths containing spaces are supported.

    Returns:
        Indices into ``sequences`` of the records present in the hhfilter
        output, in the order they appear there. An empty input yields ``[]``
        without invoking the program.
    """
    sequences = list(sequences)
    if not sequences:
        # Nothing to filter; avoid launching hhfilter on an empty FASTA file.
        return []

    # Prefer the RAM-backed /dev/shm when available to keep the round trip
    # through the filesystem cheap; otherwise use the default temp location.
    shm_dir = "/dev/shm"
    scratch_root = shm_dir if os.path.exists(shm_dir) else None

    with tempfile.TemporaryDirectory(dir=scratch_root) as tempdirname:
        tempdir = Path(tempdirname)

        fasta_file = tempdir / "input.fasta"
        fasta_file.write_text(
            "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
        )
        output_file = tempdir / "output.fasta"

        # Build argv directly instead of joining and re-splitting a string, so
        # that whitespace inside a path cannot break an argument apart.
        command = [
            binary,
            "-i", str(fasta_file),
            "-M", "a3m",
            "-o", str(output_file),
            "-id", str(seqid),
            "-diff", str(diff),
            "-cov", str(cov),
            "-qid", str(qid),
            "-qsc", str(qsc),
        ]
        # NOTE(review): presumably raises when hhfilter exits with an error;
        # confirm against esmfold2_system.run_subprocess_with_errorcheck.
        run_subprocess_with_errorcheck(command, capture_output=True)

        # Each header line is ">" followed by the original index written above.
        with output_file.open() as f:
            return [int(line[1:].strip()) for line in f if line.startswith(">")]