import torch
from transformers import PreTrainedTokenizerBase
from typing import Dict, List, Any


class DynamicPaddingDataCollater:
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer
        if tokenizer.pad_token_id is None:
            print("Warning: Tokenizer does not have a pad_token_id. Using 0 for input_ids and attention_mask padding.")
            self.padding_value_input = 0
        else:
            self.padding_value_input = tokenizer.pad_token_id
        # Padding value for labels; reuse the input padding value so it is never None.
        self.padding_value_label = self.padding_value_input

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Split each sequence into prompt tokens and completion (label) tokens
        # according to the per-token completion_mask.
        processed_features = []
        for feature in features:
            input_ids = feature["input_ids"]
            completion_mask = feature["completion_mask"]
            prompt_ids = [token for token, is_completion in zip(input_ids, completion_mask) if not is_completion]
            label_ids = [token for token, is_completion in zip(input_ids, completion_mask) if is_completion]
            processed_features.append({
                "prompt_ids": prompt_ids,
                "label_ids": label_ids,
                "original": feature
            })

        max_prompt_len = max(len(f["prompt_ids"]) for f in processed_features)
        max_label_len = max(len(f["label_ids"]) for f in processed_features)

        padded_prompt_ids = []
        padded_input_attention_mask = []
        padded_label_ids = []
        padded_labels_attention_mask = []
        for feature in processed_features:
            prompt_ids = feature["prompt_ids"]
            label_ids = feature["label_ids"]
            # Prompts are left-padded to the longest prompt in the batch.
            num_input_pads = max_prompt_len - len(prompt_ids)
            padded_prompt_ids.append([self.padding_value_input] * num_input_pads + prompt_ids)
            input_attention_mask = [1] * len(prompt_ids)
            num_input_mask_pads = max_prompt_len - len(input_attention_mask)
            padded_input_attention_mask.append([0] * num_input_mask_pads + input_attention_mask)
            # Labels are right-padded to the longest label in the batch.
            num_label_pads = max_label_len - len(label_ids)
            padded_label_ids.append(label_ids + [self.padding_value_label] * num_label_pads)
            labels_attention_mask = [1] * len(label_ids)
            num_label_mask_pads = max_label_len - len(labels_attention_mask)
            padded_labels_attention_mask.append(labels_attention_mask + [0] * num_label_mask_pads)

        batch = {
            "prompt_ids": torch.tensor(padded_prompt_ids, dtype=torch.long),
            "prompt_attention_mask": torch.tensor(padded_input_attention_mask, dtype=torch.long),
            "label_ids": torch.tensor(padded_label_ids, dtype=torch.long),
            "label_attention_mask": torch.tensor(padded_labels_attention_mask, dtype=torch.long),
        }
        batch["raw_samples"] = [f["original"] for f in processed_features]
        return batch
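

# Minimal usage sketch (hypothetical): assumes a tokenizer with a usable pad token and
# pre-tokenized features carrying "input_ids" plus a per-token "completion_mask"
# (0 for prompt tokens, 1 for completion tokens). The "gpt2" checkpoint and the toy
# features below are illustrative only, not part of this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

    features = [
        {"input_ids": [101, 2009, 2003, 102, 7592, 999], "completion_mask": [0, 0, 0, 0, 1, 1]},
        {"input_ids": [101, 2129, 102, 2074], "completion_mask": [0, 0, 0, 1]},
    ]

    collator = DynamicPaddingDataCollater(tokenizer)
    loader = DataLoader(features, batch_size=2, collate_fn=collator)
    batch = next(iter(loader))
    # Prompts are left-padded to the longest prompt, labels right-padded to the longest label.
    print(batch["prompt_ids"].shape, batch["label_ids"].shape)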