File size: 2,997 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import PreTrainedTokenizerBase

class DynamicPaddingDataCollater:
    """Collate prompt/completion features into per-batch padded tensors.

    Each feature must provide ``"input_ids"`` and a parallel
    ``"completion_mask"``. Tokens whose mask entry is falsy are collected into
    the prompt; tokens whose mask entry is truthy are collected into the labels
    (original order is preserved). Prompts are left-padded and labels are
    right-padded, each to the longest length within the batch.

    Args:
        tokenizer: Tokenizer whose ``pad_token_id`` is used as the padding
            value. When it is ``None``, 0 is used instead and a warning is
            emitted.
        label_pad_token_id: Optional padding value used for ``label_ids`` only
            (e.g. ``-100``). Defaults to the prompt padding value.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        label_pad_token_id: Optional[int] = None,
    ):
        self.tokenizer = tokenizer

        if tokenizer.pad_token_id is None:
            warnings.warn(
                "Tokenizer does not have a pad_token_id. "
                "Using 0 as the padding token id.",
                stacklevel=2,
            )
            self.padding_value_input = 0
        else:
            self.padding_value_input = tokenizer.pad_token_id

        # Padding value for labels. Falls back to the prompt padding value so
        # it is never None (torch.tensor cannot build a long tensor from None).
        if label_pad_token_id is None:
            self.padding_value_label = self.padding_value_input
        else:
            self.padding_value_label = label_pad_token_id

    @staticmethod
    def _split_prompt_and_completion(
        feature: Dict[str, Any],
    ) -> Tuple[List[Any], List[Any]]:
        """Split ``feature["input_ids"]`` into ``(prompt_ids, label_ids)``.

        Raises:
            ValueError: if ``input_ids`` and ``completion_mask`` differ in
                length (``zip`` would otherwise silently drop the excess).
        """
        input_ids = feature["input_ids"]
        completion_mask = feature["completion_mask"]
        if len(input_ids) != len(completion_mask):
            raise ValueError(
                "input_ids and completion_mask must have the same length, "
                f"got {len(input_ids)} and {len(completion_mask)}."
            )

        prompt_ids: List[Any] = []
        label_ids: List[Any] = []
        for token, is_completion in zip(input_ids, completion_mask):
            if is_completion:
                label_ids.append(token)
            else:
                prompt_ids.append(token)
        return prompt_ids, label_ids

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad a batch of features.

        Returns:
            A dict with ``prompt_ids`` / ``prompt_attention_mask`` of shape
            (batch, max_prompt_len) and ``label_ids`` / ``label_attention_mask``
            of shape (batch, max_label_len), all ``torch.long``; attention
            masks hold 1 for real tokens and 0 for padding. ``raw_samples`` is
            a list of the input feature dicts, in order.

        Raises:
            ValueError: if ``features`` is empty, or a feature's
                ``input_ids`` and ``completion_mask`` lengths differ.
        """
        if not features:
            raise ValueError("Cannot collate an empty list of features.")

        splits = [self._split_prompt_and_completion(feature) for feature in features]
        max_prompt_len = max(len(prompt_ids) for prompt_ids, _ in splits)
        max_label_len = max(len(label_ids) for _, label_ids in splits)

        padded_prompt_ids = []
        prompt_attention_mask = []
        padded_label_ids = []
        label_attention_mask = []
        for prompt_ids, label_ids in splits:
            # Prompts are left-padded (padding precedes the real tokens).
            num_prompt_pads = max_prompt_len - len(prompt_ids)
            padded_prompt_ids.append(
                [self.padding_value_input] * num_prompt_pads + prompt_ids
            )
            prompt_attention_mask.append([0] * num_prompt_pads + [1] * len(prompt_ids))

            # Labels are right-padded (padding follows the real tokens).
            num_label_pads = max_label_len - len(label_ids)
            padded_label_ids.append(
                label_ids + [self.padding_value_label] * num_label_pads
            )
            label_attention_mask.append([1] * len(label_ids) + [0] * num_label_pads)

        return {
            "prompt_ids": torch.tensor(padded_prompt_ids, dtype=torch.long),
            "prompt_attention_mask": torch.tensor(prompt_attention_mask, dtype=torch.long),
            "label_ids": torch.tensor(padded_label_ids, dtype=torch.long),
            "label_attention_mask": torch.tensor(label_attention_mask, dtype=torch.long),
            "raw_samples": list(features),
        }