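"""Visualize which nucleotides a fine-tuned DNA BERT model attends to.

The input sequence is split into overlapping k-mers, run through the model,
and the attention each token receives from the [CLS] token is summed over
the selected layers and all heads. Token-level scores are then mapped back
to per-nucleotide scores and rendered as a heatmap.

Example invocation (script name and paths are placeholders):

    python visualize.py --kmer 6 --model_path /path/to/6mer-model \
        --start_layer 11 --end_layer 11
"""
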
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import os
import numpy as np

# BertModel and DNATokenizer come from the project's modified `transformers`
# package (the stock library does not ship a DNATokenizer).
from transformers import BertModel, DNATokenizer
from process_pretrain_data import get_kmer_sentence


def format_attention(attention):
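    """Stack the per-layer attention tuple returned by the model into a
    single (num_layers, num_heads, seq_len, seq_len) tensor, squeezing
    out the batch dimension (the batch size must be 1)."""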
    squeezed = []
    for layer_attention in attention:
        # 1 x num_heads x seq_len x seq_len
        if len(layer_attention.shape) != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        squeezed.append(layer_attention.squeeze(0))
    # num_layers x num_heads x seq_len x seq_len
    return torch.stack(squeezed)

def get_attention_dna(model, tokenizer, sentence_a, start, end):
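    """For each non-special token of ``sentence_a``, return the attention
    it receives from the [CLS] token, summed over layers ``start`` through
    ``end`` (inclusive) and over all attention heads."""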
    inputs = tokenizer.encode_plus(sentence_a, sentence_b=None, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # With output_attentions=True, the last element of the model's output
    # tuple holds the per-layer attention tensors.
    attention = model(input_ids)[-1]
    input_id_list = input_ids[0].tolist() # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list) 
    attn = format_attention(attention)
    attn_score = []
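    # Skip [CLS] (index 0) and [SEP] (last index); attn[layers, heads, 0, i]
    # is the attention that [CLS] (query position 0) pays to token i.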
    for i in range(1, len(tokens)-1):
        attn_score.append(float(attn[start:end+1,:,0,i].sum()))
    return attn_score

def get_real_score(attention_scores, kmer, metric):
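    """Map per-k-mer attention scores back onto individual nucleotides.

    Each nucleotide is covered by up to ``kmer`` overlapping k-mers; with
    metric == "mean", its score is the average score of the k-mers that
    cover it.
    """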
    counts = np.zeros([len(attention_scores)+kmer-1])
    real_scores = np.zeros([len(attention_scores)+kmer-1])

    if metric == "mean":
        for i, score in enumerate(attention_scores):
            for j in range(kmer):
                counts[i+j] += 1.0
                real_scores[i+j] += score

        real_scores = real_scores/counts
    else:
        # Only "mean" is implemented; fail loudly instead of silently
        # returning all-zero scores.
        raise NotImplementedError("Unsupported metric: {}".format(metric))

    return real_scores
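
# Worked example for get_real_score with metric="mean": for kmer = 3 and
# k-mer scores [a, b, c] over a 5-nt sequence, position 2 (0-based) is
# covered by all three k-mers, so its score is (a + b + c) / 3, while
# position 0 is covered only by the first k-mer and keeps score a.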

# Default DNA sequence used when --sequence is not given.
SEQUENCE = "TGCCTGGCTTTTTGTAATTTTTGAAGAGACGGGGTTTTGCCATGATG"

def Visualize(args):
    if args.kmer == 0:
        # Average over the 3- to 6-mer models; each model is expected in a
        # subdirectory of --model_path named after its k-mer length.
        KMER_LIST = [3, 4, 5, 6]

        for kmer in KMER_LIST:
            tokenizer_name = 'dna' + str(kmer)
            model_path = os.path.join(args.model_path, str(kmer))
            model = BertModel.from_pretrained(model_path, output_attentions=True)
            tokenizer = DNATokenizer.from_pretrained(tokenizer_name, do_lower_case=False)
            raw_sentence = args.sequence if args.sequence else SEQUENCE
            sentence_a = get_kmer_sentence(raw_sentence, kmer)

            attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
            attention_scores = np.array(attention).reshape(-1, 1)

            # Normalize each model's scores before accumulating so that no
            # single k-mer model dominates the combined heatmap.
            real_scores = get_real_score(attention_scores, kmer, args.metric)
            real_scores = real_scores / np.linalg.norm(real_scores)

            if kmer == KMER_LIST[0]:
                scores = real_scores.reshape(1, real_scores.shape[0])
            else:
                scores += real_scores.reshape(1, real_scores.shape[0])

    else:
        # Load a single fine-tuned model for the requested k-mer length.
        tokenizer_name = 'dna' + str(args.kmer)
        model_path = args.model_path
        model = BertModel.from_pretrained(model_path, output_attentions=True)
        tokenizer = DNATokenizer.from_pretrained(tokenizer_name, do_lower_case=False)
        raw_sentence = args.sequence if args.sequence else SEQUENCE
        sentence_a = get_kmer_sentence(raw_sentence, args.kmer)

        attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
        attention_scores = np.array(attention).reshape(-1, 1)

        real_scores = get_real_score(attention_scores, args.kmer, args.metric)
        scores = real_scores.reshape(1, real_scores.shape[0])

    # Report the mean per-nucleotide score alongside the raw scores.
    ave = np.sum(scores) / scores.shape[1]
    print(ave)
    print(scores)

    # Render per-nucleotide attention as a 1 x sequence_length heatmap.
    sns.set()
    sns.heatmap(scores, cmap='YlGnBu', vmin=0)
    plt.show()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kmer",
        default=0,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--model_path",
        default="/home/zhihan/dna/dna-transformers/examples/ft/690/p53-small/TAp73beta/3/",
        type=str,
        help="The path of the finetuned model",
    )
    parser.add_argument(
        "--start_layer",
        default=11,
        type=int,
        help="Which layer to start",
    )
    parser.add_argument(
        "--end_layer",
        default=11,
        type=int,
        help="which layer to end",
    )
    parser.add_argument(
        "--metric",
        default="mean",
        type=str,
        help="the metric used for integrate predicted kmer result to real result",
    )
    parser.add_argument(
        "--sequence",
        default=None,
        type=str,
        help="the sequence for visualize",
    )

    args = parser.parse_args()
    Visualize(args)


if __name__ == "__main__":
    main()