| |
import ast
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
|
|
|
|
class GPT2QADataset(Dataset):
    '''
    Dataset used for the Yuyuan medical QA task.

    Only supports small datasets: files up to 5 GiB are read fully into
    memory; larger files are streamed line by line (which may be slow).
    For large datasets prefer a memory-mapped dataset (work in progress).
    '''

    def __init__(self, data_path, name, args):
        """
        Args:
            data_path: path to a text file with one Python-literal dict per
                line; each dict must contain 'Question' and 'answer' keys.
            name: human-readable dataset name shown in the progress bar.
            args: namespace providing `pretrained_model_path` and
                `max_seq_length`.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        if self.tokenizer.pad_token is None:
            # GPT2-style tokenizers ship without a pad token; reuse EOS.
            self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        # File size in GiB — decides whether load_data slurps or streams.
        self.data_size = os.path.getsize(data_path) / 1024 / 1024 / 1024
        self.data_type_name = name
        # Set before load_data so every attribute encode() needs is ready
        # even if callers parse-and-encode eagerly.
        self.max_seq_length = args.max_seq_length
        self.data = self.load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Tokenization happens lazily, per sample.
        return self.encode(self.data[index])

    def load_data(self, data_path):
        """
        Parse every line of `data_path` into a dict.

        Files up to 5 GiB are read fully into memory so tqdm can show a
        total; larger files are streamed and the handle is closed afterwards
        (in a `finally`, so it cannot leak if a line fails to parse).
        """
        if self.data_size <= 5:
            with open(data_path, "rt", encoding='utf8') as f:
                lines = f.readlines()
            total_num = len(lines)
            data_gen = lines
        else:
            data_gen = open(data_path, "rt", encoding='utf8')
            total_num = None

        data = []
        try:
            with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度',
                      mininterval=0.3) as bar:
                for line in data_gen:
                    data.append(self.data_parse(line))
                    bar.update()
        finally:
            # Streaming mode opened the file without a context manager.
            if self.data_size > 5:
                data_gen.close()
        return data

    def data_parse(self, line):
        """
        Parse one line into a dict.

        Uses ast.literal_eval instead of eval: the lines are expected to be
        Python literals, and literal_eval cannot execute arbitrary code from
        an untrusted data file.
        """
        return ast.literal_eval(line.strip())

    def encode(self, item):
        """
        Convert one QA record into model training inputs.

        Returns a dict of squeezed tensors (input_ids, attention_mask,
        labels) plus the raw question/answer strings; pad positions in
        `labels` are set to -100 so the loss ignores them.
        """
        inputs_dict = self.tokenizer.encode_plus(item['Question'] + item['answer'],
                                                 max_length=self.max_seq_length,
                                                 padding='max_length',
                                                 truncation=True,
                                                 return_tensors='pt')
        target = inputs_dict['input_ids']
        labels = target.clone().detach()
        labels[target == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs_dict['input_ids'].squeeze(),
            "attention_mask": inputs_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "question": item['Question'],
            "answer": item['answer']
        }
|
|
|
|
class GPT2QADataModel(pl.LightningDataModule):
    """Lightning data module wiring the medical-QA datasets to dataloaders."""

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register this module's command-line options on `parent_args`."""
        group = parent_args.add_argument_group('GPT2QADataModel')
        group.add_argument('--data_dir', type=str, required=True)
        group.add_argument('--num_workers', default=2, type=int)
        group.add_argument('--train_data', default='train.txt', type=str)
        group.add_argument('--valid_data', default='valid.txt', type=str)
        group.add_argument('--test_data', default='test.txt', type=str)
        group.add_argument('--train_batchsize', type=int, required=True)
        group.add_argument('--valid_batchsize', type=int, required=True)
        group.add_argument('--max_seq_length', default=1024, type=int)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        # Datasets are only built for training runs; eval-only runs skip
        # the (potentially slow) file loading entirely.
        if not args.do_eval_only:
            data_dir = args.data_dir
            self.train_data = GPT2QADataset(
                os.path.join(data_dir, args.train_data), '训练集', args)
            self.valid_data = GPT2QADataset(
                os.path.join(data_dir, args.valid_data), '验证集', args)
            self.test_data = GPT2QADataset(
                os.path.join(data_dir, args.test_data), '测试集', args)

    def train_dataloader(self):
        """Shuffled training loader."""
        return DataLoader(
            self.train_data,
            shuffle=True,
            batch_size=self.train_batchsize,
            pin_memory=False,
            num_workers=self.args.num_workers,
        )

    def val_dataloader(self):
        """Deterministic validation loader."""
        return DataLoader(
            self.valid_data,
            shuffle=False,
            batch_size=self.valid_batchsize,
            pin_memory=False,
            num_workers=self.args.num_workers,
        )

    def predict_dataloader(self):
        """Deterministic prediction loader over the test split."""
        return DataLoader(
            self.test_data,
            shuffle=False,
            batch_size=self.valid_batchsize,
            pin_memory=False,
            num_workers=self.args.num_workers,
        )
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the dataset against a local model checkpoint and
    # print one encoded sample.
    import argparse
    modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
    datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
    parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
    group = parser.add_argument_group(title='test args')
    # Fixed help text: this option is the pretrained model path; the old
    # string ("Number of transformer layers.") was copied from an
    # unrelated argument.
    group.add_argument('--pretrained-model-path', type=str, default=modelfile,
                       help='Path to the pretrained HF model/tokenizer.')
    group.add_argument('--max-seq-length', type=int, default=1024)
    args = parser.parse_args()

    testml = GPT2QADataset(datafile, 'medical_qa', args=args)

    print(testml[10])
|
|