File size: 2,686 Bytes
4bd136e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Data collators with label masking for training
"""
import torch


class AssistantSpanCollator:
    """
    Collator that masks labels to only train on assistant responses.
    
    For where=="mot": labels only inside <MOT_BEGIN>...<MOT_END> in assistant
    For where=="text": labels entire assistant span (for M2T tasks)
    """
    
    def __init__(self, tokenizer, max_length):
        self.tok = tokenizer
        self.max_len = max_length
        
        # Get special token IDs
        self.im_start = self.tok.convert_tokens_to_ids("<|im_start|>")
        self.im_end = self.tok.convert_tokens_to_ids("<|im_end|>")
        self.mot_beg = self.tok.convert_tokens_to_ids("<MOT_BEGIN>")
        self.mot_end = self.tok.convert_tokens_to_ids("<MOT_END>")
    
    def __call__(self, examples):
        texts = [e["text"] for e in examples]
        wheres = [e["where"] for e in examples]
        
        # Tokenize
        enc = self.tok(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len
        )
        
        input_ids = enc["input_ids"]
        labels = input_ids.clone().fill_(-100)
        
        # Apply label masking per example
        for i, w in enumerate(wheres):
            seq = input_ids[i]
            
            # Find last <|im_start|> (start of assistant)
            starts = (seq == self.im_start).nonzero(as_tuple=True)[0]
            if starts.numel() == 0:
                continue
            
            a_start = int(starts[-1].item())
            
            # Find corresponding <|im_end|>
            sub = seq[a_start+1:]
            ends = (sub == self.im_end).nonzero(as_tuple=True)[0]
            a_end = (a_start + 1 + int(ends[0].item())) if ends.numel() > 0 else (seq.size(0) - 1)
            
            if w == "text":
                # Label entire assistant span
                labels[i, a_start+1:a_end] = seq[a_start+1:a_end]
            else:
                # Label only motion tokens between <MOT_BEGIN> and <MOT_END>
                asst = seq[a_start+1:a_end]
                bpos = (asst == self.mot_beg).nonzero(as_tuple=True)[0]
                epos = (asst == self.mot_end).nonzero(as_tuple=True)[0]
                
                if bpos.numel() > 0 and epos.numel() > 0 and epos[0] >= bpos[0]:
                    b = a_start + 1 + int(bpos[0].item())
                    e = a_start + 1 + int(epos[0].item())
                    labels[i, b:e+1] = seq[b:e+1]
        
        return {
            "input_ids": input_ids,
            "attention_mask": enc["attention_mask"],
            "labels": labels
        }