# ProtT3_model / data_provider / llm_tuning_dm.py
# yuccaaa's picture
# Add files using upload-large-folder tool
# 4d12519 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pytorch_lightning import LightningDataModule
from data_provider.gal_helpers import escape_custom_split_sequence
from data_provider.stage1_dm import SwissProtDataset, OntoProteinDataset
from torch.utils.data import DataLoader, ConcatDataset
class LLMTuningCollater:
    """Collate ``(protein, prompt_template, text, index)`` samples for training.

    Each protein sequence is substituted into its own prompt template. The
    formatted prompts are tokenized with left padding to ``prot_max_len``, and
    the target texts with right padding to ``text_max_len``. The per-sample
    index is ignored.

    Note: this mutates ``tokenizer.padding_side`` and leaves it set to
    ``'right'`` on return.
    """

    def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
        self.tokenizer = tokenizer
        self.use_gal = use_gal
        self.prot_max_len = prot_max_len
        self.text_max_len = text_max_len

    def __call__(self, batch):
        """Return ``(prot_batch, text_batch)`` tokenizer outputs for ``batch``."""
        proteins, templates, descriptions, _ = zip(*batch)

        prompts = []
        for template, protein in zip(templates, proteins):
            prompts.append(template.format(protein))
        if self.use_gal:
            # Galactica-specific escaping applied to the formatted prompt.
            prompts = [escape_custom_split_sequence(item) for item in prompts]

        tok = self.tokenizer
        shared_kwargs = dict(
            truncation=True,
            padding='max_length',
            add_special_tokens=True,
            return_tensors='pt',
            return_attention_mask=True,
        )

        # Prompts are left-padded, presumably so the prompt ends right before
        # the target tokens for a decoder-only LM -- confirm against model code.
        tok.padding_side = 'left'
        prot_batch = tok(text=prompts, max_length=self.prot_max_len, **shared_kwargs)

        # Targets are right-padded.
        tok.padding_side = 'right'
        text_batch = tok(text=descriptions, max_length=self.text_max_len, **shared_kwargs)

        return prot_batch, text_batch
class InferenceCollater:
    """Collate ``(protein, prompt_template, text, index)`` samples for generation.

    Only the formatted prompts are tokenized, with left padding to
    ``prot_max_len``. The raw target texts and sample indices are passed
    through untouched in a dict.

    Note: this sets ``tokenizer.padding_side`` to ``'left'`` and does not
    restore it.
    """

    def __init__(self, tokenizer, text_max_len, prot_max_len, use_gal):
        self.tokenizer = tokenizer
        self.use_gal = use_gal
        self.prot_max_len = prot_max_len
        # Kept for signature parity with LLMTuningCollater; not read in __call__.
        self.text_max_len = text_max_len

    def __call__(self, batch):
        """Return ``(prot_batch, {'targets': ..., 'indices': ...})``."""
        proteins, templates, descriptions, sample_ids = zip(*batch)

        prompts = []
        for template, protein in zip(templates, proteins):
            prompts.append(template.format(protein))
        if self.use_gal:
            # Galactica-specific escaping applied to the formatted prompt.
            prompts = [escape_custom_split_sequence(item) for item in prompts]

        self.tokenizer.padding_side = 'left'
        encoded_prompts = self.tokenizer(
            text=prompts,
            max_length=self.prot_max_len,
            padding='max_length',
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        targets = {'targets': descriptions, 'indices': sample_ids}
        return encoded_prompts, targets
class LLMTuningDM(LightningDataModule):
    """Lightning data module for LLM tuning on a single protein-text dataset.

    The dataset is chosen from ``root``:

    * a path containing ``'SwissProtV3'`` loads ``train_set.json`` /
      ``valid_set.json`` / ``test_set.json`` with ``SwissProtDataset``;
    * a path containing ``'OntoProteinDatasetV2'`` loads ``train.txt`` /
      ``valid.txt`` / ``test.txt`` with ``OntoProteinDataset``;
    * anything else raises ``NotImplementedError``.

    Call :meth:`init_tokenizer` before requesting any dataloader. Each
    collater captures ``self.tokenizer`` at the moment its loader is built.
    """

    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.prot_max_len = args.prot_max_len
        self.text_max_len = args.text_max_len

        if 'SwissProtV3' in root:
            dataset_cls = SwissProtDataset
            prompt = '[START_AMINO]{}[END_AMINO]. Swiss-Prot description: '
            train_file, val_file, test_file = 'train_set.json', 'valid_set.json', 'test_set.json'
        elif 'OntoProteinDatasetV2' in root:
            dataset_cls = OntoProteinDataset
            prompt = '[START_AMINO]{}[END_AMINO]. Gene Ontology description: '
            train_file, val_file, test_file = 'train.txt', 'valid.txt', 'test.txt'
        else:
            raise NotImplementedError(f'Unsupported dataset root: {root!r}')

        def load(split_file):
            # return_prompt=True makes samples carry their prompt template,
            # which both collaters unpack as the second tuple element.
            return dataset_cls(root + '/' + split_file, prompt=prompt, return_prompt=True)

        self.train_dataset = load(train_file)
        self.val_dataset = load(val_file)
        self.test_dataset = load(test_file)

        self.tokenizer = None
        self.use_gal = 'gal' in args.llm_name

    def init_tokenizer(self, tokenizer):
        """Set the tokenizer used by collaters of subsequently built loaders."""
        self.tokenizer = tokenizer

    def _make_loader(self, dataset, batch_size, collater_cls, *, shuffle=False, drop_last=False):
        """Build a DataLoader sharing this module's worker and collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=drop_last,
            persistent_workers=False,
            collate_fn=collater_cls(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )

    def train_dataloader(self):
        """Shuffled training loader that drops the last incomplete batch."""
        return self._make_loader(
            self.train_dataset, self.batch_size, LLMTuningCollater,
            shuffle=True, drop_last=True,
        )

    def val_dataloader(self):
        """Return ``[val_loader, test_loader]``.

        The validation split uses ``LLMTuningCollater`` with ``batch_size``.
        The test split uses ``InferenceCollater`` with ``inference_batch_size``.
        """
        val_loader = self._make_loader(self.val_dataset, self.batch_size, LLMTuningCollater)
        test_loader = self._make_loader(self.test_dataset, self.inference_batch_size, InferenceCollater)
        return [val_loader, test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the "Data module" argument group to ``parent_parser`` and return it.

        ``--q_max_len``, ``--a_max_len``, ``--prompt`` and ``--filter_side_qa``
        are not read by this class; presumably other modules consume them.
        """
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--root', type=str, default='data/SwissProtV3')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser
class LLMTuningMixDM(LightningDataModule):
    """Lightning data module mixing Swiss-Prot and OntoProtein for LLM tuning.

    ``root`` must be the parent directory containing both ``SwissProtV3/`` and
    ``OntoProteinDatasetV2/``. Training concatenates both training splits.
    Validation yields separate val/test loaders for each dataset.

    Call :meth:`init_tokenizer` before requesting any dataloader. Each
    collater captures ``self.tokenizer`` at the moment its loader is built.
    """

    def __init__(
        self,
        root: str = 'data/',
        args=None,
    ):
        super().__init__()
        self.batch_size = args.batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = args.num_workers
        self.prot_max_len = args.prot_max_len
        self.text_max_len = args.text_max_len

        swiss_prompt = '[START_AMINO]{}[END_AMINO]. Swiss-Prot description: '
        onto_prompt = '[START_AMINO]{}[END_AMINO]. Gene Ontology description: '

        # return_prompt=True makes samples carry their prompt template,
        # which both collaters unpack as the second tuple element.
        def swiss(split_file):
            return SwissProtDataset(root + '/SwissProtV3/' + split_file, prompt=swiss_prompt, return_prompt=True)

        def onto(split_file):
            return OntoProteinDataset(root + '/OntoProteinDatasetV2/' + split_file, prompt=onto_prompt, return_prompt=True)

        self.train_dataset = ConcatDataset([swiss('train_set.json'), onto('train.txt')])
        self.swiss_val_dataset = swiss('valid_set.json')
        self.onto_val_dataset = onto('valid.txt')
        self.swiss_test_dataset = swiss('test_set.json')
        self.onto_test_dataset = onto('test.txt')

        self.tokenizer = None
        self.use_gal = 'gal' in args.llm_name

    def init_tokenizer(self, tokenizer):
        """Set the tokenizer used by collaters of subsequently built loaders."""
        self.tokenizer = tokenizer

    def _make_loader(self, dataset, batch_size, collater_cls, *, shuffle=False, drop_last=False):
        """Build a DataLoader sharing this module's worker and collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=drop_last,
            persistent_workers=False,
            collate_fn=collater_cls(self.tokenizer, self.text_max_len, self.prot_max_len, self.use_gal),
        )

    def train_dataloader(self):
        """Shuffled loader over both training splits, dropping the last partial batch."""
        return self._make_loader(
            self.train_dataset, self.batch_size, LLMTuningCollater,
            shuffle=True, drop_last=True,
        )

    def val_dataloader(self):
        """Return ``[swiss_val, swiss_test, onto_val, onto_test]`` loaders.

        The val loaders use ``LLMTuningCollater`` with ``batch_size``. The test
        loaders use ``InferenceCollater`` with ``inference_batch_size``.
        """
        swiss_val_loader = self._make_loader(self.swiss_val_dataset, self.batch_size, LLMTuningCollater)
        swiss_test_loader = self._make_loader(self.swiss_test_dataset, self.inference_batch_size, InferenceCollater)
        onto_val_loader = self._make_loader(self.onto_val_dataset, self.batch_size, LLMTuningCollater)
        onto_test_loader = self._make_loader(self.onto_test_dataset, self.inference_batch_size, InferenceCollater)
        return [swiss_val_loader, swiss_test_loader, onto_val_loader, onto_test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add the "Data module" argument group to ``parent_parser`` and return it.

        ``--q_max_len``, ``--a_max_len``, ``--prompt`` and ``--filter_side_qa``
        are not read by this class; presumably other modules consume them.
        """
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        # This module appends '/SwissProtV3/...' and '/OntoProteinDatasetV2/...'
        # to root, so the default must be the parent data directory (the old
        # 'data/SwissProtV3' default produced 'data/SwissProtV3/SwissProtV3/...').
        parser.add_argument('--root', type=str, default='data/')
        parser.add_argument('--text_max_len', type=int, default=128)
        parser.add_argument('--prot_max_len', type=int, default=1024)
        parser.add_argument('--q_max_len', type=int, default=1064)
        parser.add_argument('--a_max_len', type=int, default=36)
        parser.add_argument('--prompt', type=str, default='[START_AMINO]{}[END_AMINO]. The protein has the following properties: ')
        parser.add_argument('--filter_side_qa', action='store_true', default=False)
        return parent_parser
if __name__ == '__main__':
    # Manual smoke test: iterate Swiss-Prot training batches through the collater.
    from transformers import AutoTokenizer

    # LLMTuningCollater unpacks (protein, prompt, text, index) samples. The data
    # modules above get those by passing prompt=... and return_prompt=True.
    dataset = SwissProtDataset(
        '../data/SwissProtV3/train_set.json',
        prompt='[START_AMINO]{}[END_AMINO].',
        return_prompt=True,
    )
    tokenizer = AutoTokenizer.from_pretrained('facebook/galactica-1.3b')
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False,
        # The collater takes exactly (tokenizer, text_max_len, prot_max_len,
        # use_gal). The prompt template comes from the dataset, not from here.
        collate_fn=LLMTuningCollater(tokenizer, 128, 1024, True),
    )
    for prot_batch, text_batch in loader:
        print(prot_batch['input_ids'].shape, text_batch['input_ids'].shape)
        input('Press Enter for the next batch...')