File size: 2,898 Bytes
fb8a87c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import tempfile
from pathlib import Path

import numpy as np
from scipy.spatial.distance import cdist

from .esmfold2_system import run_subprocess_with_errorcheck


def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
    """Greedily select sequences that either maximize or minimize hamming distance.



    Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,

    iteratively add sequences to the list with the maximum (minimum) average Hamming

    distance to the existing set of sequences.



    Args:

        array (np.ndarray): Character array representing the sequences in the MSA

        num_seqs (int): Number of sequences to select.

        mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless

            you're doing it to prove a point for a paper.



    Returns:

        list[int]: List of indices to select from the array

    """
    assert mode in ("max", "min")
    depth = array.shape[0]
    if depth <= num_seqs:
        return list(range(depth))
    array = array.view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(depth)
    indices = [0]
    pairwise_distances = np.zeros((0, depth))
    for _ in range(num_seqs - 1):
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(index)
    indices = sorted(indices)
    return indices


def hhfilter(

    sequences: list[str],

    seqid: int = 90,

    diff: int = 0,

    cov: int = 0,

    qid: int = 0,

    qsc: float = -20.0,

    binary: str = "hhfilter",

) -> list[int]:
    with tempfile.TemporaryDirectory(
        dir="/dev/shm" if os.path.exists("/dev/shm") else None
    ) as tempdirname:
        tempdir = Path(tempdirname)
        fasta_file = tempdir / "input.fasta"
        fasta_file.write_text(
            "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
        )
        output_file = tempdir / "output.fasta"
        command = " ".join(
            [
                f"{binary}",
                f"-i {fasta_file}",
                "-M a3m",
                f"-o {output_file}",
                f"-id {seqid}",
                f"-diff {diff}",
                f"-cov {cov}",
                f"-qid {qid}",
                f"-qsc {qsc}",
            ]
        ).split(" ")
        run_subprocess_with_errorcheck(command, capture_output=True)
        with output_file.open() as f:
            indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
        return indices