File size: 5,486 Bytes
712d350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import string, re
import numpy as np


def res_to_list(logits, seq):
    """
    Trim `logits` to the length of `seq`.

    Returns the first ``len(seq)`` entries of `logits`; anything after that
    is dropped.
    """
    n_residues = len(seq)
    return logits[:n_residues]

def res_to_seq(a, mode='mean'):
    """
    Function for how we go from n_values for each amino acid to n_values for each sequence.

    The last element of `a` is read as the number of leading entries to keep;
    everything after that count (padding tokens, and the count itself) is
    left out of the result.

    Args:
        a: Indexable container whose final element is the number of entries
            to keep. For 'sum' and 'mean' the entries are sliced from `a`
            itself; for 'restore' they are sliced from ``a[0]``.
        mode: One of 'sum', 'mean' or 'restore'.

    Returns:
        'sum' / 'mean': the sum / mean over the kept entries of `a`.
        'restore': the kept entries of ``a[0]``, unreduced.

    Raises:
        ValueError: If `mode` is not one of the supported values (previously
            this silently returned None).
    """

    if mode not in ('sum', 'mean', 'restore'):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'sum', 'mean' or 'restore'."
        )

    # The trailing element stores how many leading entries are not padding.
    n_residues = int(a[-1])

    if mode == 'restore':
        return a[0][:n_residues]

    residues = a[:n_residues]
    return residues.sum() if mode == 'sum' else residues.mean()

def get_number_alignment(numbered_seqs):
    """
    Creates a number alignment from the anarci results.

    Every numbered sequence is turned into a three-column frame (columns 0, 1
    and 'resi'). The frames are stacked, de-duplicated on column 0, stripped
    of rows whose 'resi' is '-', and finally inner-merged onto the frame from
    `get_max_alignment()` so that only positions present in the input remain.

    Returns a DataFrame with columns 0 and 1.
    """
    import pandas as pd

    per_sequence_frames = []
    for aligned_seq in numbered_seqs:
        per_sequence_frames.append(
            pd.DataFrame(aligned_seq, columns=[0, 1, 'resi'])
        )

    # Keep one row per value in column 0 (the first occurrence wins).
    observed_positions = pd.concat(per_sequence_frames).drop_duplicates(subset=0)
    reference_order = get_max_alignment()

    non_gap_positions = observed_positions.query("resi!='-'")
    ordered_positions = reference_order.merge(
        non_gap_positions, left_on=0, right_on=0
    )

    return ordered_positions[[0, 1]]

def get_max_alignment():
    """
    Create maximum possible alignment for sorting

    Builds a one-column DataFrame of ``(number, insertion_code)`` tuples that
    starts with ``("<", "")`` and ends with ``(">", "")``. In between, every
    number from 1 to 128 contributes 27 rows: the blank code ``' '`` plus the
    letters A-Z. For numbers 33, 61 and 112 the order is reversed (Z down to
    A, then the blank code); for all others it is blank first, then A to Z.
    """
    import pandas as pd

    reversed_numbers = (33, 61, 112)
    forward_codes = [' ', *string.ascii_uppercase]
    backward_codes = forward_codes[::-1]

    rows = [[("<", "")]]
    for number in range(1, 129):
        codes = backward_codes if number in reversed_numbers else forward_codes
        rows.extend([(number, code)] for code in codes)
    rows.append([(">", "")])

    return pd.DataFrame(rows)


def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10):
    """
    Number paired sequences and build a combined number alignment.

    Each entry of `ab_seqs` is a string; any '<' and '>' characters are
    removed and the remainder is split on '|'. The first part is numbered as
    chain 'H' and the second as chain 'L' through `unpaired_msa_numbering`.

    Returns a tuple ``(numbered_seqs, seqs, number_alignment)`` in which the
    'H' and 'L' results are joined with a '|' separator: a ``(("|", ""), "|", "|")``
    entry in each numbered sequence, a '|' character in each sequence string
    and a ``[("|", ""), "|"]`` row in the number alignment.
    """

    import pandas as pd

    chain_pairs = []
    for pair in ab_seqs:
        stripped = pair.replace(">", "").replace("<", "")
        chain_pairs.append(stripped.split("|"))

    shared_options = dict(fragmented = fragmented, n_jobs = n_jobs)

    heavy_numbered, heavy_seqs, heavy_alignment = unpaired_msa_numbering(
        [parts[0] for parts in chain_pairs], 'H', **shared_options
    )
    light_numbered, light_seqs, light_alignment = unpaired_msa_numbering(
        [parts[1] for parts in chain_pairs], 'L', **shared_options
    )

    # One separator row sits between the two chains in the alignment.
    separator_row = pd.DataFrame([[("|",""), "|"]])
    stacked_alignment = pd.concat([heavy_alignment, separator_row, light_alignment])
    number_alignment = stacked_alignment.reset_index(drop=True)

    seqs = []
    for heavy, light in zip(heavy_seqs, light_seqs):
        seqs.append(f"{heavy}|{light}")

    separator_entry = (("|",""), "|", "|")
    numbered_seqs = []
    for heavy, light in zip(heavy_numbered, light_numbered):
        numbered_seqs.append([*heavy, separator_entry, *light])

    return numbered_seqs, seqs, number_alignment


def unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10):
    """
    Number single-chain sequences and derive their number alignment.

    The sequences are numbered with `number_with_anarci`, the alignment is
    built with `get_number_alignment`, and column 1 of that alignment is set
    to `chain`.

    Returns a tuple ``(numbered_seqs, seqs, number_alignment)`` where `seqs`
    is rebuilt from the third element of every numbered entry with '-'
    characters removed.
    """

    numbered_seqs = number_with_anarci(
        seqs, chain = chain, fragmented = fragmented, n_jobs = n_jobs
    )

    number_alignment = get_number_alignment(numbered_seqs)
    number_alignment[1] = chain

    rebuilt_seqs = []
    for numbered_seq in numbered_seqs:
        residues = ''.join(entry[2] for entry in numbered_seq)
        rebuilt_seqs.append(residues.replace('-', ''))

    return numbered_seqs, rebuilt_seqs, number_alignment


def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1):
    """
    Number sequences with ANARCI using the 'imgt' scheme.

    `seqs` is passed to ``anarci.run_anarci`` as ``[index, sequence]`` rows,
    with `n_jobs` as ``ncpu`` and species restricted to human and mouse.

    For every entry in the second element of the ANARCI result, the numbered
    residues are read from ``entry[0][0]`` (presumably the numbering of the
    first detected domain -- confirm against the ANARCI API). Entries whose
    residue is '-' are skipped, and each kept residue becomes
    ``(position, chain, residue)``.

    Unless `fragmented` is true, each numbered sequence is wrapped in a
    leading ``(("<", ""), chain, "<")`` and a trailing
    ``((">", ""), chain, ">")`` entry.

    Returns a list with one numbered sequence (list of tuples) per result.
    """

    import anarci
    import pandas as pd

    indexed_seqs = pd.DataFrame(seqs).reset_index().values.tolist()
    anarci_out = anarci.run_anarci(
        indexed_seqs,
        ncpu=n_jobs,
        scheme='imgt',
        allowed_species=['human', 'mouse'],
    )

    start_entry = (("<",""), chain, "<")
    end_entry = ((">",""), chain, ">")

    numbered_seqs = []
    for numbering in anarci_out[1]:
        residues = [
            (entry[0], chain, entry[1])
            for entry in numbering[0][0]
            if entry[1] != '-'
        ]
        numbered_seqs.append(
            residues if fragmented else [start_entry] + residues + [end_entry]
        )

    return numbered_seqs


def create_alignment(res_embeds, numbered_seqs, seq, number_alignment):
    """
    Spread per-residue embeddings over the rows of a number alignment.

    `numbered_seqs` is loaded into a DataFrame and left-merged onto
    `number_alignment` on columns 0 and 1. Alignment rows with no match are
    filled with '-', and column 2 of the result is the aligned sequence.

    The first ``len(seq)`` rows of `res_embeds` are then padded with a zero
    row at every '-' position so that both line up row by row.

    Returns a 2-D array: the aligned embedding columns followed by one column
    holding the aligned sequence characters.
    """

    import pandas as pd

    residue_table = pd.DataFrame(numbered_seqs)
    merged = number_alignment.merge(residue_table, how='left', on=[0, 1])
    sequence_alignment = merged.fillna('-')[2]

    gap_rows = np.where(sequence_alignment.values == '-')[0]
    # np.insert indexes into the un-padded array, so the k-th gap must be
    # shifted back by the k gaps that precede it.
    insert_at = gap_rows - np.arange(len(gap_rows))

    padded_embeds = np.insert(res_embeds[:len(seq)], insert_at, 0, axis=0)
    aligned_embeds = pd.DataFrame(padded_embeds)

    combined = pd.concat([aligned_embeds, sequence_alignment], axis=1)
    return combined.values


def get_spread_sequences(seq, spread, start_position):
    """
    Test sequences which are 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions).

    Produces 11 copies of `seq`, each prefixed with a run of '*' whose length
    goes from ``start_position - 8`` to ``start_position + 2``. A
    non-positive length yields no prefix at all.

    Note that `spread` is accepted but not read by this function.

    Returns a numpy array of the prefixed strings.
    """
    pad_lengths = range(start_position - 8, start_position + 3)
    return np.array(['*' * pad + seq for pad in pad_lengths])

def get_sequences_from_anarci(out_anarci, max_position, spread):
    """
    Ensures correct masking on each side of sequence

    `out_anarci` is handled as text. If it equals 'ANARCI_error', an array of
    `spread` copies of 'ANARCI-ERR' is returned. Otherwise the text is parsed
    with regular expressions to obtain:

    * the end position: the last run of digits in the text;
    * the start position: the leading number of the first match of the
      four-entry pattern below, minus one;
    * the sequence: the character following every ``), '`` in the text.

    'X' is replaced by '*', the sequence is right-padded with '*' up to
    `max_position`, and the result is handed to `get_spread_sequences`.
    """

    if out_anarci == 'ANARCI_error':
        return np.array(['ANARCI-ERR']*spread)

    # Search the reversed text so the first hit is the last number, then flip
    # the digits back into reading order.
    last_number = re.search(r'\d+', out_anarci[::-1]).group()[::-1]
    end_position = int(last_number)

    # Fixes ANARCI error of poor numbering of the CDR1 region
    # (the pattern appears intended to anchor on the first four consecutive
    # numbered entries whose residue text contains no '-').
    anchor = re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+',
                       out_anarci)
    first_number = anchor.group().split(',')[0]
    start_position = int(first_number) - 1

    residue_hits = "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci))
    sequence = "".join(re.findall(r"(?i)[A-Z*]", residue_hits))

    right_padding = '*'*(max_position-end_position)
    sequence_j = sequence.replace('-','').replace('X','*') + right_padding

    return get_spread_sequences(sequence_j, spread, start_position)