Added msa.py utils
Browse files
app.py
CHANGED
|
@@ -64,12 +64,17 @@ def msa_embed(msa):
|
|
| 64 |
msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
|
| 65 |
msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
|
| 66 |
|
| 67 |
-
|
|
|
|
| 68 |
temp = temp[12][:,:,0,:]
|
| 69 |
temp = torch.mean(temp,(0,1))
|
| 70 |
return temp
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def download_data_if_required():
|
| 74 |
url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
|
| 75 |
fps = [pg.trained_model_fp]
|
|
|
|
| 64 |
msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
|
| 65 |
msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
|
| 66 |
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
|
| 69 |
temp = temp[12][:,:,0,:]
|
| 70 |
temp = torch.mean(temp,(0,1))
|
| 71 |
return temp
|
| 72 |
|
| 73 |
|
| 74 |
+
def go_embed(terms):
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
|
| 78 |
def download_data_if_required():
|
| 79 |
url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
|
| 80 |
fps = [pg.trained_model_fp]
|
msa.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import itertools
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
|
| 5 |
+
import string
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from scipy.spatial.distance import squareform, pdist, cdist
|
| 10 |
+
from Bio import SeqIO
|
| 11 |
+
#import biotite.structure as bs
|
| 12 |
+
#from biotite.structure.io.pdbx import PDBxFile, get_structure
|
| 13 |
+
#from biotite.database import rcsb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# This is an efficient way to delete lowercase characters and insertion characters from a string
|
| 19 |
+
deletekeys = dict.fromkeys(string.ascii_lowercase)
|
| 20 |
+
deletekeys["."] = None
|
| 21 |
+
deletekeys["*"] = None
|
| 22 |
+
translation = str.maketrans(deletekeys)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_sequence(filename: str) -> Tuple[str, str]:
|
| 26 |
+
""" Reads the first (reference) sequences from a fasta or MSA file."""
|
| 27 |
+
record = next(SeqIO.parse(filename, "fasta"))
|
| 28 |
+
return record.description, str(record.seq)
|
| 29 |
+
|
| 30 |
+
def remove_insertions(sequence: str) -> str:
|
| 31 |
+
""" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
|
| 32 |
+
return sequence.translate(translation)
|
| 33 |
+
|
| 34 |
+
def read_msa(filename: str) -> List[Tuple[str, str]]:
|
| 35 |
+
""" Reads the sequences from an MSA file, automatically removes insertions."""
|
| 36 |
+
return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
|
| 40 |
+
"""
|
| 41 |
+
Select sequences from the MSA to maximize the hamming distance
|
| 42 |
+
Alternatively, can use hhfilter
|
| 43 |
+
"""
|
| 44 |
+
assert mode in ("max", "min")
|
| 45 |
+
if len(msa) <= num_seqs:
|
| 46 |
+
return msa
|
| 47 |
+
|
| 48 |
+
array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
|
| 49 |
+
|
| 50 |
+
optfunc = np.argmax if mode == "max" else np.argmin
|
| 51 |
+
all_indices = np.arange(len(msa))
|
| 52 |
+
indices = [0]
|
| 53 |
+
pairwise_distances = np.zeros((0, len(msa)))
|
| 54 |
+
for _ in range(num_seqs - 1):
|
| 55 |
+
dist = cdist(array[indices[-1:]], array, "hamming")
|
| 56 |
+
pairwise_distances = np.concatenate([pairwise_distances, dist])
|
| 57 |
+
shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
|
| 58 |
+
shifted_index = optfunc(shifted_distance)
|
| 59 |
+
index = np.delete(all_indices, indices)[shifted_index]
|
| 60 |
+
indices.append(index)
|
| 61 |
+
indices = sorted(indices)
|
| 62 |
+
return [msa[idx] for idx in indices]
|