"""Async batch iterator for training with background tokenization."""

from typing import Dict, List, Optional, Any, Iterator

import torch

from taoTrain.data.tokenization_queue import TokenizationQueue
from taoTrain.data.sft_utils import build_response_only_next_token_labels


class AsyncBatchIterator:
    """
    Iterator that yields batches from a tokenization queue.

    This allows batches to be consumed directly from the background
    tokenization thread without waiting for all chunks to be tokenized upfront.

    The iterator:
    1. Pulls pre-tokenized chunks from the TokenizationQueue
    2. Assembles ``batch_size`` samples at a time (a batch may span chunks)
    3. Builds next-token labels (SFT or pretrain mode, depending on the chunk)
    4. Creates the batch tensors directly on the target device

    ``gradient_accumulation_steps`` is stored on the instance but is not used
    by any logic in this class.
    """

    # Seconds to wait on the queue per poll before re-checking for exhaustion.
    CHUNK_POLL_TIMEOUT_S = 30.0

    # Label value marking positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(
        self,
        tokenization_queue: TokenizationQueue,
        batch_size: int,
        device: torch.device,
        drop_last: bool = True,
        gradient_accumulation_steps: int = 1,
    ):
        """
        Initialize async batch iterator.

        Args:
            tokenization_queue: TokenizationQueue instance
            batch_size: Batch size for yielding batches (must be >= 1)
            device: torch.device to create batch tensors on
            drop_last: If True, drop last incomplete batch
            gradient_accumulation_steps: For logging purposes (not used here)

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Validate before assigning anything: a non-positive batch size would
        # otherwise yield no batches at all and make __len__ divide by zero.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.queue = tokenization_queue
        self.batch_size = batch_size
        self.device = device
        self.drop_last = drop_last
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # State for iteration
        self._current_chunk: Optional[Dict[str, List]] = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Reset per-epoch state, (re)start the queue, and return ``self``.

        If the queue has already handed out chunks it is reset for the next
        epoch; worker threads are started only when none are registered.
        """
        # Reset state for new epoch
        self._current_chunk = None
        self._current_idx = 0
        self._samples_yielded = 0
        self._finished = False

        # Reset tokenization queue for epochs 2+
        if self.queue._next_chunk_idx > 0:
            print(f"\n✓ Resetting TokenizationQueue for next epoch (cur_idx={self.queue._next_chunk_idx})")
            self.queue.reset_for_next_epoch()

        # Start tokenization threads only if none exist yet
        if not self.queue._threads:
            print("\n✓ Starting TokenizationQueue worker threads...")
            self.queue.start()
        else:
            print(f"\n⚠ TokenizationQueue threads already running: {len(self.queue._threads)} active")

        return self

    def __next__(self) -> Dict[str, torch.Tensor]:
        """
        Get next batch.

        Returns:
            Dict with 'input_ids', 'attention_mask', 'labels'
            (all as torch tensors on device)

        Raises:
            StopIteration: When no more batches available
        """
        batch = self._get_next_batch()
        if batch is None:
            print("AsyncBatchIterator: No more batches available, stopping iteration.")
            raise StopIteration
        return batch

    def _chunk_consumed(self) -> bool:
        """Return True when there is no current chunk or it has no samples left."""
        if self._current_chunk is None:
            return True
        return self._current_idx >= len(self._current_chunk["input_ids"])

    def _load_next_chunk(self) -> bool:
        """
        Block until a non-empty chunk is available or the queue is exhausted.

        Returns:
            True if a new chunk was installed as the current chunk, False if
            the queue reported exhaustion (``_finished`` is then set).
        """
        while True:
            chunk = self.queue.get_next_chunk(timeout=self.CHUNK_POLL_TIMEOUT_S)

            if chunk is None:
                if not self.queue.is_exhausted:
                    # Poll timed out but the queue is not done yet: keep waiting.
                    continue

                # Queue exhausted
                chunk_count = getattr(self.queue, "_next_chunk_idx", "unknown")
                chunk_order = getattr(self.queue, "_chunk_order", None)
                total_chunks = len(chunk_order) if chunk_order is not None else "unknown"
                print(f"AsyncBatchIterator: No more chunks (processed {chunk_count}/{total_chunks})")
                print(f"AsyncBatchIterator: Samples yielded so far: {self._samples_yielded}")
                self._current_chunk = None
                self._finished = True
                return False

            if len(chunk["input_ids"]) == 0:
                # An empty chunk has nothing to index into; fetch the next one.
                continue

            self._current_chunk = chunk
            self._current_idx = 0
            return True

    @staticmethod
    def _build_pretrain_labels(input_ids: List[int], attention_mask: List[int]) -> List[int]:
        """
        Build next-token labels for pretraining.

        Position ``i`` is labelled with the token at position ``i + 1``; the
        final position gets -100. A position is also set to -100 when it is
        padding itself (``attention_mask[i] == 0``) or when its target token
        is padding (``attention_mask[i + 1] == 0``), so that padding never
        becomes a training target.

        Args:
            input_ids: Token IDs of one sample.
            attention_mask: Mask for the same sample; 0 marks padding.
                Assumed to have the same length as ``input_ids``.

        Returns:
            A new list of labels with the same length as ``input_ids``.
        """
        ignore = AsyncBatchIterator.IGNORE_INDEX
        labels = input_ids[1:] + [ignore]  # Shift by one; last position has no target
        last = len(attention_mask) - 1
        for i, mask_val in enumerate(attention_mask):
            # Because labels are shifted, the target of position i lives at
            # i + 1; checking only attention_mask[i] would leave the last real
            # token of a right-padded sequence predicting a padding token.
            target_is_padding = i < last and attention_mask[i + 1] == 0
            if mask_val == 0 or target_is_padding:
                labels[i] = ignore
        return labels

    def _get_next_batch(self) -> Optional[Dict[str, torch.Tensor]]:
        """
        Fetch and collate the next batch.

        Samples are taken one at a time from the current chunk; when a chunk
        runs out the next one is pulled from the queue, so a batch may span
        chunk boundaries.

        Returns:
            Dict with batch tensors, or None if iteration is exhausted or the
            final batch is incomplete and ``drop_last`` is True.
        """
        if self._finished:
            # The queue already reported exhaustion; do not poll it again
            # (a further get_next_chunk call could wait for the full timeout).
            return None

        batch_input_ids: List[List[int]] = []
        batch_attention_masks: List[List[int]] = []
        batch_labels: List[List[int]] = []

        while len(batch_input_ids) < self.batch_size:
            if self._chunk_consumed() and not self._load_next_chunk():
                break

            chunk = self._current_chunk
            idx = self._current_idx
            input_ids = chunk["input_ids"][idx]
            attention_mask = chunk["attention_mask"][idx]

            # Generate labels based on SFT or pretrain mode
            if "mask" in chunk:
                # SFT mode: the per-token mask selects which tokens to train on
                # (mask=0 -> label=-100 (ignore), mask=1 -> train on); the
                # exact label layout is defined by the helper.
                labels = build_response_only_next_token_labels(input_ids, chunk["mask"][idx])
            else:
                # Pretrain mode: shift labels by 1 for next-token prediction
                labels = self._build_pretrain_labels(input_ids, attention_mask)

            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
            batch_labels.append(labels)

            self._current_idx += 1
            self._samples_yielded += 1

        # Return batch if we have any samples, respecting drop_last
        collected = len(batch_input_ids)
        if collected == 0:
            print(f"AsyncBatchIterator: No samples collected for batch. Finished={self._finished}, returning None.")
            return None

        if collected < self.batch_size and self.drop_last:
            incomplete_pct = (collected / self.batch_size) * 100
            print(f"AsyncBatchIterator: Batch incomplete ({collected}/{self.batch_size} = {incomplete_pct:.1f}%) and drop_last=True, returning None.")
            return None

        return self._collate_batch(batch_input_ids, batch_attention_masks, batch_labels)

    def _collate_batch(
        self,
        batch_input_ids: List[List[int]],
        batch_attention_masks: List[List[int]],
        batch_labels: List[List[int]],
    ) -> Dict[str, torch.Tensor]:
        """
        Collate batch samples into tensors created on ``self.device``.

        Args:
            batch_input_ids: List of token ID lists
            batch_attention_masks: List of attention mask lists
            batch_labels: List of label lists

        Returns:
            Dict with 'input_ids', 'attention_mask' and 'labels' as
            ``torch.long`` tensors on device.

        Note:
            No padding is done here, so all samples in a batch are assumed to
            already have the same length (torch.tensor rejects ragged lists).
        """
        def to_tensor(rows: List[List[int]]) -> torch.Tensor:
            return torch.tensor(rows, dtype=torch.long, device=self.device)

        return {
            "input_ids": to_tensor(batch_input_ids),
            "attention_mask": to_tensor(batch_attention_masks),
            "labels": to_tensor(batch_labels),
        }

    def __len__(self) -> int:
        """Return approximate number of batches, based on ``len(self.queue)``."""
        total_samples = len(self.queue)
        if self.drop_last:
            return total_samples // self.batch_size
        # Ceiling division: the final partial batch counts as one batch.
        return (total_samples + self.batch_size - 1) // self.batch_size

    def shutdown(self):
        """Shut down the underlying tokenization queue, waiting for it to finish."""
        self.queue.shutdown(wait=True)

    def __del__(self):
        """Cleanup on deletion."""
        # __init__ may have raised before `queue` was assigned; in that case
        # there is nothing to shut down.
        if getattr(self, "queue", None) is not None:
            self.shutdown()