SignMotionGPT / collators.py
rdz-falcon's picture
Deploy SignMotionGPT Demo with LFS
4bd136e
"""
Data collators with label masking for training
"""
import torch
class AssistantSpanCollator:
"""
Collator that masks labels to only train on assistant responses.
For where=="mot": labels only inside <MOT_BEGIN>...<MOT_END> in assistant
For where=="text": labels entire assistant span (for M2T tasks)
"""
def __init__(self, tokenizer, max_length):
self.tok = tokenizer
self.max_len = max_length
# Get special token IDs
self.im_start = self.tok.convert_tokens_to_ids("<|im_start|>")
self.im_end = self.tok.convert_tokens_to_ids("<|im_end|>")
self.mot_beg = self.tok.convert_tokens_to_ids("<MOT_BEGIN>")
self.mot_end = self.tok.convert_tokens_to_ids("<MOT_END>")
def __call__(self, examples):
texts = [e["text"] for e in examples]
wheres = [e["where"] for e in examples]
# Tokenize
enc = self.tok(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_len
)
input_ids = enc["input_ids"]
labels = input_ids.clone().fill_(-100)
# Apply label masking per example
for i, w in enumerate(wheres):
seq = input_ids[i]
# Find last <|im_start|> (start of assistant)
starts = (seq == self.im_start).nonzero(as_tuple=True)[0]
if starts.numel() == 0:
continue
a_start = int(starts[-1].item())
# Find corresponding <|im_end|>
sub = seq[a_start+1:]
ends = (sub == self.im_end).nonzero(as_tuple=True)[0]
a_end = (a_start + 1 + int(ends[0].item())) if ends.numel() > 0 else (seq.size(0) - 1)
if w == "text":
# Label entire assistant span
labels[i, a_start+1:a_end] = seq[a_start+1:a_end]
else:
# Label only motion tokens between <MOT_BEGIN> and <MOT_END>
asst = seq[a_start+1:a_end]
bpos = (asst == self.mot_beg).nonzero(as_tuple=True)[0]
epos = (asst == self.mot_end).nonzero(as_tuple=True)[0]
if bpos.numel() > 0 and epos.numel() > 0 and epos[0] >= bpos[0]:
b = a_start + 1 + int(bpos[0].item())
e = a_start + 1 + int(epos[0].item())
labels[i, b:e+1] = seq[b:e+1]
return {
"input_ids": input_ids,
"attention_mask": enc["attention_mask"],
"labels": labels
}