# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
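# Data module and collaters for protein question-answering (PDBQADataset)
# fine-tuning of a causal language model.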
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from data_provider.prot_qa_dm import PDBQADataset
from data_provider.gal_helpers import escape_custom_split_sequence


class LLMTuningProtQACollater(object):
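    """Training collater: formats prompts and tokenizes question-answer pairs."""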
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal
        self.prompt = prompt
        assert prompt.count('{}') == 2, 'prompt needs two {} placeholders (protein sequence, question)'
        
    def __call__(self, batch):
        prot_seqs, questions, answers, q_types = zip(*batch)
        assert len(prot_seqs) == len(questions) == len(answers)
        questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        # A trailing newline delimits the end of each answer.
        answers = [a + '\n' for a in answers]
        # Disabled code path: tokenize questions (left-padded) and answers
        # (right-padded) as two separate batches.
        if False:
            self.tokenizer.padding_side = 'left'
            q_batch = self.tokenizer(questions,
                                    truncation=True,
                                    padding='max_length',
                                    add_special_tokens=True,
                                    max_length=self.q_max_len,
                                    return_tensors='pt',
                                    return_attention_mask=True, 
                                    return_token_type_ids=False)
            self.tokenizer.padding_side = 'right'
            a_batch = self.tokenizer(answers,
                                    truncation=True,
                                    padding='max_length',
                                    add_special_tokens=True,
                                    max_length=self.a_max_len,
                                    return_tensors='pt',
                                    return_attention_mask=True, 
                                    return_token_type_ids=False)
            return q_batch, a_batch
        else:
            # Default path: tokenize each question-answer pair jointly as one
            # right-padded sequence of length q_max_len + a_max_len.
            self.tokenizer.padding_side = 'right'
            qa_pair = [[q, a] for q, a in zip(questions, answers)]
            qa_batch = self.tokenizer(qa_pair,
                                      truncation=True,
                                      padding='max_length',
                                      add_special_tokens=True,
                                      max_length=self.q_max_len + self.a_max_len,
                                      return_tensors='pt',
                                      return_attention_mask=True,
                                      return_token_type_ids=True)
            return qa_batch


class InferenceCollater(object):
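    """Inference collater: tokenizes left-padded questions and passes the raw
    answers, question types, and sample indices through for evaluation."""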
    def __init__(self, tokenizer, q_max_len, a_max_len, use_gal, prompt):
        self.tokenizer = tokenizer
        self.q_max_len = q_max_len
        self.a_max_len = a_max_len
        self.use_gal = use_gal
        self.prompt = prompt
        assert prompt.count('{}') == 2, 'prompt needs two {} placeholders (protein sequence, question)'
        
    def __call__(self, batch):
        prot_seqs, questions, answers, q_types, indices = zip(*batch)
        assert len(prot_seqs) == len(questions) == len(answers)
        questions = [self.prompt.format(prot_seqs[i], questions[i]) for i in range(len(prot_seqs))]

        if self.use_gal:
            questions = [escape_custom_split_sequence(q) for q in questions]
        answers = [a + '\n' for a in answers]
        # Left-pad so generated tokens continue directly after each prompt.
        self.tokenizer.padding_side = 'left'
        q_batch = self.tokenizer(questions,
                                 truncation=True,
                                 padding='max_length',
                                 add_special_tokens=True,
                                 max_length=self.q_max_len,
                                 return_tensors='pt',
                                 return_attention_mask=True, 
                                 return_token_type_ids=False)
        target_dict = {'targets': answers, 'q_types': q_types, 'indices': indices}
        return q_batch, target_dict


class LLMTuningProtQADM(LightningDataModule):
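    """LightningDataModule providing PDBQA train/val/test dataloaders."""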
    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.q_max_len = args.q_max_len
        self.a_max_len = args.a_max_len
        self.prompt = args.prompt
        
        self.train_dataset = PDBQADataset(root, 'train.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
        self.val_dataset = PDBQADataset(root, 'val.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
        self.test_dataset = PDBQADataset(root, 'test.txt', "Question: {} Answer:", filter_side_qa=args.filter_side_qa)
        
        self.tokenizer = None  # set later via init_tokenizer()
        # Galactica-family models ('gal' in the name) need their custom
        # split-sequence escaping applied to prompts.
        self.use_gal = 'gal' in args.llm_name
    
    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def train_dataloader(self):
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=False,
            collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        return loader

    def val_dataloader(self):
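        # Return both the val and test loaders; Lightning iterates each
        # during validation.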
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=LLMTuningProtQACollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=False,
            collate_fn=InferenceCollater(self.tokenizer, self.q_max_len, self.a_max_len, self.use_gal, self.prompt),
        )
        return [val_loader, test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. {}')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser