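# File: TaoNet-mini-T2/code/TaoTrain/tests/test_sft_masking.py
# Hosting-page residue (not part of the module): avatar caption "StarMist0012's picture";
# commit message "Add files using upload-large-folder tool"; commit e2bfccc (verified).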
import torch
from taoTrain.data.async_loader import AsyncBatchIterator
from taoTrain.data.sft_jsonl import SFTJSONLDataset
from taoTrain.data.sft_utils import build_response_only_next_token_labels
def test_response_only_labels_mask_target_token_not_input_token():
    """Check that next-token labels are gated by the mask of the *target* token.

    For each position ``i`` the expected label is ``input_ids[i + 1]`` when
    ``mask[i + 1]`` is 1 and ``-100`` otherwise; the final position has no next
    token and is also ``-100``. Gating on ``mask[i]`` (the input token) instead
    would produce a different result for this fixture, which is what the test
    guards against.
    """
    input_ids = [10, 20, 30, 40, 2, 0]
    mask = [0, 0, 1, 1, 0, 0]

    labels = build_response_only_next_token_labels(input_ids, mask)

    # Positions 1 and 2 predict tokens 30 and 40 (the only tokens with mask == 1);
    # every other position is -100, presumably the loss ignore index.
    assert labels == [-100, 30, 40, -100, -100, -100]
def test_sft_dataset_direct_path_matches_response_only_helper():
    """Check that indexing the dataset yields the response-only label layout.

    The dataset is created with ``__new__`` so ``__init__`` never runs, then a
    single in-memory chunk is injected by hand. The ``labels`` tensor of the
    first sample must follow the same target-token masking rule that
    ``build_response_only_next_token_labels`` is tested for.
    """
    # Bypass __init__ so the test supplies the dataset's state directly.
    dataset = SFTJSONLDataset.__new__(SFTJSONLDataset)
    # NOTE(review): presumably a None chunk manager selects the "direct" path
    # that reads from _current_chunk_data - confirm against SFTJSONLDataset.
    dataset.chunk_manager = None
    dataset._current_chunk_data = {
        "input_ids": [[10, 20, 30, 40, 2, 0]],
        "attention_mask": [[1, 1, 1, 1, 1, 0]],
        "mask": [[0, 0, 1, 1, 0, 0]],
    }

    sample = dataset[0]

    expected_labels = torch.tensor([-100, 30, 40, -100, -100, -100])
    assert torch.equal(sample["labels"], expected_labels)
class _OneChunkQueue:
    """Test double for a tokenization queue that serves exactly one chunk.

    The first call to ``get_next_chunk`` returns the fixed chunk; every later
    call returns ``None``. ``is_exhausted`` flips to ``True`` once the chunk has
    been handed out.
    """

    def __init__(self):
        # A single pre-tokenized sample, stored column-wise (one list per field).
        self._chunk = {
            "input_ids": [[10, 20, 30, 40, 2, 0]],
            "attention_mask": [[1, 1, 1, 1, 1, 0]],
            "mask": [[0, 0, 1, 1, 0, 0]],
        }
        self._returned = False
        # NOTE(review): the attributes below are never read inside this class;
        # presumably the consumer (AsyncBatchIterator) inspects them - confirm.
        self._next_chunk_idx = 0
        self._chunk_order = [0]
        self._threads = [object()]

    def get_next_chunk(self, timeout=None):
        """Return the chunk on the first call and ``None`` afterwards.

        ``timeout`` is accepted for interface compatibility and ignored.
        """
        if self._returned:
            return None
        self._returned = True
        return self._chunk

    @property
    def is_exhausted(self) -> bool:
        """Whether the single chunk has already been handed out."""
        return self._returned

    def shutdown(self, wait=True) -> None:
        """Do nothing; there are no real worker threads to stop."""
        return None

    def __len__(self) -> int:
        """Return the number of chunks this queue serves, which is always 1."""
        return 1
def test_async_sft_loader_matches_direct_dataset_labels():
    """Check that the async batch iterator emits the response-only label layout.

    The iterator is fed by ``_OneChunkQueue``, which serves the same single
    sample used in the direct-dataset test, so the first batch must carry the
    same labels with an added leading batch dimension of size 1.
    """
    iterator = AsyncBatchIterator(
        tokenization_queue=_OneChunkQueue(),
        batch_size=1,
        device=torch.device("cpu"),
        drop_last=True,
    )

    # Only the first batch is needed; the queue holds a single sample.
    batch = next(iter(iterator))

    expected_labels = torch.tensor([[-100, 30, 40, -100, -100, -100]])
    assert torch.equal(batch["labels"], expected_labels)