Spaces:
Paused
Paused
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    """Question-answer supervised fine-tuning dataset.

    Reads a JSONL file where each line is an object with ``"question"`` and
    ``"answer"`` keys, and tokenizes each pair into ``input_ids`` /
    ``attention_mask`` / ``labels`` for causal-LM training. The prompt
    portion is masked out of the loss with ``-100``.
    """

    def __init__(self, data_path, tokenizer, max_source_length: int, max_target_length: int) -> None:
        """
        Args:
            data_path: path to a JSONL file (falsy to start empty).
            tokenizer: HF-style tokenizer exposing ``apply_chat_template``,
                ``__call__`` and ``pad_token_id``.
            max_source_length: token budget for the chat prompt.
            max_target_length: token budget for the answer.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # Overall budget for prompt + answer (+ trailing pad/eos marker).
        self.max_seq_length = self.max_source_length + self.max_target_length
        self.data = []
        if data_path:
            with open(data_path, "r", encoding='utf-8') as f:
                for line in f:
                    # FIX: strip before testing — the original check
                    # (`not line or line == ""`) never matched a bare "\n"
                    # line, so json.loads would raise on blank lines.
                    line = line.strip()
                    if not line:
                        continue
                    json_line = json.loads(line)
                    self.data.append({
                        "question": json_line["question"],
                        "answer": json_line["answer"],
                    })
        print("data load , size:", len(self.data))

    def preprocess(self, question, answer):
        """Tokenize one QA pair.

        Returns:
            Tuple of plain lists ``(input_ids, attention_mask, labels)``,
            each at most ``max_seq_length`` long. ``labels`` masks the
            prompt with ``-100`` so only the answer (plus the trailing pad
            token used as an end-of-text marker) contributes to the loss.
        """
        messages = [
            {"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},
            {"role": "user", "content": question}
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # FIX: pass truncation=True — HF tokenizers ignore max_length unless
        # truncation is enabled, so the per-part budgets were not actually
        # enforced and the final manual slice could cut the trailing
        # pad/eos token out of the labels.
        instruction = self.tokenizer(
            prompt, add_special_tokens=False,
            max_length=self.max_source_length, truncation=True,
        )
        response = self.tokenizer(
            answer, add_special_tokens=False,
            max_length=self.max_target_length, truncation=True,
        )
        # pad_token_id is appended as an end-of-text marker and is a real
        # training target (it also appears in labels).
        input_ids = instruction["input_ids"] + response["input_ids"] + [self.tokenizer.pad_token_id]
        attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
        labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [self.tokenizer.pad_token_id]
        # Defensive cap: with truncation enabled above, the sequence is at
        # most max_seq_length + 1, so at worst only the final pad token is
        # trimmed here (only when both parts hit their budget).
        if len(input_ids) > self.max_seq_length:
            input_ids = input_ids[:self.max_seq_length]
            attention_mask = attention_mask[:self.max_seq_length]
            labels = labels[:self.max_seq_length]
        return input_ids, attention_mask, labels

    def __getitem__(self, index):
        """Return one tokenized example as a dict of LongTensors."""
        item_data = self.data[index]
        input_ids, attention_mask, labels = self.preprocess(**item_data)
        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "attention_mask": torch.LongTensor(np.array(attention_mask)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)