"""Response-only SFT label tests.

The same toy sample is pushed through three code paths -- the
``build_response_only_next_token_labels`` helper, ``SFTJSONLDataset``
indexing, and ``AsyncBatchIterator`` batching -- and every path must
produce the same next-token labels.
"""

import torch

from taoTrain.data.async_loader import AsyncBatchIterator
from taoTrain.data.sft_jsonl import SFTJSONLDataset
from taoTrain.data.sft_utils import build_response_only_next_token_labels

# Shared toy sample. Tuples keep the canonical values immutable; each test
# builds its own lists from them so no test can leak mutations into another.
_INPUT_IDS = (10, 20, 30, 40, 2, 0)
# The final position is excluded by the attention mask.
_ATTENTION_MASK = (1, 1, 1, 1, 1, 0)
# Response mask: only tokens 30 and 40 (positions 2 and 3) are response tokens.
_RESPONSE_MASK = (0, 0, 1, 1, 0, 0)
# Position i is labelled with input_ids[i + 1] only when that *target*
# position is in the response mask; every other position gets -100
# (presumably the loss ignore_index -- confirm against the trainer).
_EXPECTED_LABELS = (-100, 30, 40, -100, -100, -100)


def _make_chunk():
    """Return a fresh single-sample chunk dict built from the shared sample."""
    return {
        "input_ids": [list(_INPUT_IDS)],
        "attention_mask": [list(_ATTENTION_MASK)],
        "mask": [list(_RESPONSE_MASK)],
    }


def test_response_only_labels_mask_target_token_not_input_token():
    """The helper keys the response mask on the target token, not the input token."""
    labels = build_response_only_next_token_labels(
        list(_INPUT_IDS), list(_RESPONSE_MASK)
    )

    assert labels == list(_EXPECTED_LABELS)


def test_sft_dataset_direct_path_matches_response_only_helper():
    """Indexing the dataset yields the same labels as the helper."""
    # Bypass __init__ so no files are read; set only the attributes
    # __getitem__ needs for the in-memory chunk path.
    dataset = SFTJSONLDataset.__new__(SFTJSONLDataset)
    dataset.chunk_manager = None
    dataset._current_chunk_data = _make_chunk()

    sample = dataset[0]

    assert torch.equal(sample["labels"], torch.tensor(list(_EXPECTED_LABELS)))


class _OneChunkQueue:
    """Stand-in tokenization queue that hands out exactly one chunk.

    ``get_next_chunk`` returns the chunk on the first call and ``None`` on
    every later call; ``is_exhausted`` flips to True once it has been served.
    """

    def __init__(self):
        self._chunk = _make_chunk()
        self._returned = False
        # NOTE(review): these private attributes mirror the real queue and are
        # presumably read by AsyncBatchIterator -- confirm before removing.
        self._next_chunk_idx = 0
        self._chunk_order = [0]
        self._threads = [object()]

    def get_next_chunk(self, timeout=None):
        """Return the single chunk once, then ``None``; ``timeout`` is ignored."""
        if self._returned:
            return None
        self._returned = True
        return self._chunk

    @property
    def is_exhausted(self):
        """True once the chunk has been handed out."""
        return self._returned

    def shutdown(self, wait=True):
        """No-op: there are no real worker threads to stop."""
        return None

    def __len__(self):
        return 1


def test_async_sft_loader_matches_direct_dataset_labels():
    """The async batching path yields the same labels as the direct dataset path."""
    iterator = AsyncBatchIterator(
        tokenization_queue=_OneChunkQueue(),
        batch_size=1,
        device=torch.device("cpu"),
        drop_last=True,
    )

    batch = next(iter(iterator))

    # batch_size=1 adds a leading batch dimension of size one.
    assert torch.equal(batch["labels"], torch.tensor([list(_EXPECTED_LABELS)]))