File size: 2,813 Bytes
712d350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from dataclasses import dataclass
import numpy as np
import torch

from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment


class AbAlignment:
    """Number antibody sequences and align batched per-residue encodings to that numbering.

    The numbering and per-sequence alignment are delegated to helpers imported
    from ``extra_utils`` (``paired_msa_numbering``, ``create_alignment``); this
    class wires their output together with batched model outputs.
    """

    def __init__(self, device='cpu', ncpu=1):
        # Stored for callers; not read by this class's own methods.
        self.device = device
        # Forwarded as ``n_jobs`` to the numbering helper.
        self.ncpu = ncpu

    def number_sequences(self, seqs, chain='H', fragmented=False):
        """Number ``seqs`` so their encodings can later be aligned.

        Args:
            seqs: Sequences to number; passed to ``paired_msa_numbering``.
            chain: Only ``'HL'`` (paired sequences) is currently supported.
            fragmented: Forwarded to the numbering helper.

        Returns:
            ``(numbered_seqs, seqs, number_alignment)`` as produced by
            ``paired_msa_numbering``.

        Raises:
            AssertionError: If ``chain`` is not ``'HL'``.
        """
        # Explicit check instead of ``assert`` so the guard also holds under
        # ``python -O``; AssertionError is kept so existing callers still catch it.
        if chain != 'HL':
            # TODO: route to ``unpaired_msa_numbering`` once unpaired alignment works.
            raise AssertionError(
                'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.'
            )
        numbered_seqs, seqs, number_alignment = paired_msa_numbering(
            seqs, fragmented=fragmented, n_jobs=self.ncpu
        )
        return numbered_seqs, seqs, number_alignment

    def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
        """Align each sequence's encoding to ``number_alignment`` and stack the results.

        ``create_alignment`` is applied to every ``(encoding, numbered_seq, seq)``
        triple. The per-sequence results are combined into one array whose first
        axis indexes the sequences, so they are expected to share a shape.
        """
        return np.array([
            create_alignment(res_embed, numbered_seq, seq, number_alignment)
            for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)
        ])

    @staticmethod
    def _subset_slices(subset_list):
        """Return, per subset, the slice it occupies in the flat concatenation of all subsets."""
        slices = []
        start = 0
        for subset in subset_list:
            stop = start + len(subset)
            slices.append(slice(start, stop))
            start = stop
        return slices

    def reformat_subsets(
        self,
        subset_list,
        mode='seqcoding',
        align=False,
        numbered_seqs=None,
        seqs=None,
        number_alignment=None,
    ):
        """Merge the outputs of consecutive batches back into a single result.

        Args:
            subset_list: Per-batch outputs, in batch order.
            mode: For ``'seqcoding'``, ``'restore'``, ``'pseudo_log_likelihood'``
                and ``'confidence'`` the batches are joined with ``np.concatenate``.
            align: For any other mode, whether to align the encodings to
                ``number_alignment``; otherwise the batches are flattened into
                one list.
            numbered_seqs, seqs: Flat (un-batched) sequences in the same order as
                the concatenation of ``subset_list``; used only when ``align``.
            number_alignment: Numbering table passed to ``align_encodings``;
                presumably a pandas DataFrame (``.apply(..., axis=1)`` is used).

        Returns:
            An ndarray, a flat list, or an ``aligned_results`` instance
            depending on ``mode`` and ``align``.
        """
        if mode in ('seqcoding', 'restore', 'pseudo_log_likelihood', 'confidence'):
            return np.concatenate(subset_list)

        if not align:
            # Flatten batches without the quadratic cost of ``sum(subset_list, [])``.
            return [item for subset in subset_list for item in subset]

        # Cumulative offsets keep sequences paired with their encodings even
        # when the batches have different lengths (e.g. a shorter last batch).
        aligned_subsets = [
            self.align_encodings(subset, numbered_seqs[span], seqs[span], number_alignment)
            for subset, span in zip(subset_list, self._subset_slices(subset_list))
        ]
        stacked = np.concatenate(aligned_subsets)

        return aligned_results(
            # The last feature column holds residue characters, joined into one
            # string per sequence; the remaining columns are the embedding.
            aligned_seqs=[''.join(residues) for residues in stacked[:, :, -1]],
            aligned_embeds=stacked[:, :, :-1].astype(float),
            number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values,
        )
        

@dataclass
class aligned_results:
    """
    Dataclass used to store the output of ``AbAlignment.reformat_subsets``
    when ``align`` is true.

    All three fields are required (no defaults).
    """

    # One aligned sequence string per input sequence.
    aligned_seqs: list
    # Float embeddings indexed as [sequence, aligned position, feature].
    aligned_embeds: np.ndarray
    # One numbering label per aligned position (array taken from pandas ``.values``).
    number_alignment: np.ndarray