TaoNet-mini-T2 / code /TaoTrain /src /taoTrain /data /async_loader.py
StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""Async batch iterator for training with background tokenization."""
from typing import Dict, List, Optional, Any, Iterator
import torch
from taoTrain.data.tokenization_queue import TokenizationQueue
from taoTrain.data.sft_utils import build_response_only_next_token_labels
class AsyncBatchIterator:
    """
    Iterator that yields collated batches from a tokenization queue.

    Batches are consumed directly from the background tokenization workers,
    so training does not have to wait for every chunk to be tokenized upfront.

    The iterator:
    1. Pulls pre-tokenized chunks from the TokenizationQueue
    2. Groups the samples of those chunks into batches of ``batch_size``
    3. Builds next-token labels for every sample
    4. Creates the batch tensors directly on the target device

    ``gradient_accumulation_steps`` is only stored on the instance; this
    class does not use it when building batches.
    """

    def __init__(
        self,
        tokenization_queue: "TokenizationQueue",
        batch_size: int,
        device: torch.device,
        drop_last: bool = True,
        gradient_accumulation_steps: int = 1,
    ):
        """
        Initialize async batch iterator.

        Args:
            tokenization_queue: TokenizationQueue instance
            batch_size: Number of samples per yielded batch
            device: torch.device the batch tensors are created on
            drop_last: If True, drop the last incomplete batch
            gradient_accumulation_steps: Stored for callers (not used here)
        """
        self.queue = tokenization_queue
        self.batch_size = batch_size
        self.device = device
        self.drop_last = drop_last
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # State for iteration
        self._current_chunk: Optional[Dict[str, List]] = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Reset per-epoch state, (re)arm the queue, and return self."""
        # Reset state for new epoch
        self._current_chunk = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

        # A non-zero chunk cursor means the queue was already consumed
        # (epochs 2+), so it has to be rewound before iterating again.
        if self.queue._next_chunk_idx > 0:
            print(f"\n✓ Resetting TokenizationQueue for next epoch (cur_idx={self.queue._next_chunk_idx})")
            self.queue.reset_for_next_epoch()

        # Start the worker threads only if none are registered yet.
        if not self.queue._threads:
            print("\n✓ Starting TokenizationQueue worker threads...")
            self.queue.start()
        else:
            print(f"\n⚠ TokenizationQueue threads already running: {len(self.queue._threads)} active")
        return self

    def __next__(self) -> Dict[str, torch.Tensor]:
        """
        Get next batch.

        Returns:
            Dict with 'input_ids', 'attention_mask', 'labels'
            (all torch.long tensors on ``self.device``)

        Raises:
            StopIteration: When no more batches are available
        """
        batch = self._get_next_batch()
        if batch is None:
            print("AsyncBatchIterator: No more batches available, stopping iteration.")
            raise StopIteration
        return batch

    def _get_next_batch(self) -> Optional[Dict[str, torch.Tensor]]:
        """
        Fetch and collate the next batch.

        Returns:
            Dict with batch tensors, or None if iteration is exhausted
            (including when the trailing incomplete batch is dropped).
        """
        # Once the queue reported exhaustion, never poll it again: the
        # iterator must keep signalling "done" without blocking on the
        # 30s chunk timeout.
        if self._finished:
            return None

        batch_input_ids = []
        batch_attention_masks = []
        batch_labels = []

        while len(batch_input_ids) < self.batch_size:
            chunk = self._current_chunk
            if chunk is None or self._current_idx >= len(chunk["input_ids"]):
                # Current chunk is missing or fully consumed: fetch the next one.
                self._current_chunk = self.queue.get_next_chunk(timeout=30.0)  # 30s polling timeout
                if self._current_chunk is None:
                    if not self.queue.is_exhausted:
                        # Nothing arrived yet but more is expected: poll again.
                        continue
                    # Queue exhausted
                    chunk_count = getattr(self.queue, "_next_chunk_idx", "unknown")
                    print(f"AsyncBatchIterator: No more chunks (processed {chunk_count}/{len(self.queue._chunk_order)})")
                    print(f"AsyncBatchIterator: Samples yielded so far: {self._samples_yielded}")
                    self._finished = True
                    break
                self._current_idx = 0
                # Re-evaluate the guard above: a chunk with zero samples must
                # be skipped instead of being indexed at position 0.
                continue

            # Get sample from current chunk
            input_ids = chunk["input_ids"][self._current_idx]
            attention_mask = chunk["attention_mask"][self._current_idx]

            # Generate labels based on SFT or pretrain mode
            if "mask" in chunk:
                # SFT mode: label construction is delegated to
                # build_response_only_next_token_labels using the per-token mask.
                mask = chunk["mask"][self._current_idx]
                labels = build_response_only_next_token_labels(input_ids, mask)
            else:
                # Pretrain mode: shift by one so position i predicts token i+1;
                # the final position has no target and is ignored (-100).
                # NOTE(review): with right padding the last real token's target
                # is the first pad token — confirm this is intended.
                labels = input_ids[1:] + [-100]

            # Positions that are padding in the attention mask are excluded
            # from the loss (-100).
            for i, mask_val in enumerate(attention_mask):
                if mask_val == 0:
                    labels[i] = -100

            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
            batch_labels.append(labels)
            self._current_idx += 1
            self._samples_yielded += 1

        # Return batch if we have any samples, respecting drop_last
        if len(batch_input_ids) == 0:
            print(f"AsyncBatchIterator: No samples collected for batch. Finished={self._finished}, returning None.")
            return None
        if len(batch_input_ids) < self.batch_size and self.drop_last:
            incomplete_pct = (len(batch_input_ids) / self.batch_size) * 100
            print(f"AsyncBatchIterator: Batch incomplete ({len(batch_input_ids)}/{self.batch_size} = {incomplete_pct:.1f}%) and drop_last=True, returning None.")
            return None
        return self._collate_batch(batch_input_ids, batch_attention_masks, batch_labels)

    def _collate_batch(
        self,
        batch_input_ids: List[List[int]],
        batch_attention_masks: List[List[int]],
        batch_labels: List[List[int]],
    ) -> Dict[str, torch.Tensor]:
        """
        Collate batch samples into tensors created on ``self.device``.

        No padding is applied here, so all samples of a batch are assumed to
        already share the same length.

        Args:
            batch_input_ids: List of token ID lists
            batch_attention_masks: List of attention mask lists
            batch_labels: List of label lists

        Returns:
            Dict with 'input_ids', 'attention_mask' and 'labels' as
            torch.long tensors on ``self.device``
        """
        # Convert to tensors
        input_ids_tensor = torch.tensor(batch_input_ids, dtype=torch.long, device=self.device)
        attention_mask_tensor = torch.tensor(batch_attention_masks, dtype=torch.long, device=self.device)
        labels_tensor = torch.tensor(batch_labels, dtype=torch.long, device=self.device)
        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "labels": labels_tensor,
        }

    def __len__(self) -> int:
        """Return approximate number of batches, derived from ``len(self.queue)``."""
        total_samples = len(self.queue)
        if self.drop_last:
            return total_samples // self.batch_size
        # Ceiling division: the trailing partial batch counts as one batch.
        return (total_samples + self.batch_size - 1) // self.batch_size

    def shutdown(self):
        """Shut down the tokenization queue, waiting for it to finish."""
        self.queue.shutdown(wait=True)

    def __del__(self):
        """Cleanup on deletion."""
        # __init__ may not have run to completion (or at all), in which case
        # there is no queue to shut down and finalization must not raise.
        if getattr(self, "queue", None) is not None:
            self.shutdown()