"""
Test script to inspect dataset loading pipeline.

Usage:
    python test_dataset.py config/pretrain.yaml
    python test_dataset.py config/sft.yaml --num-samples 10
    python test_dataset.py config/sft.yaml --sft --num-samples 5  # Show SFT-specific info

Note: ``--num-samples`` is the number of *batches* fetched and displayed.
"""

import argparse
import re
import sys
import traceback
from itertools import islice
from pathlib import Path
from typing import Any, Dict, NoReturn, Optional, Tuple

import torch
from tqdm import tqdm

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from taoTrain.config import load_config, TrainingModeEnum, PretrainConfig
from taoTrain.core import create_datasets
from taoTrain.data import TokenizationQueue, AsyncBatchIterator
from taoTrain.data import BaseJSONLDataset, SFTJSONLDataset, parse_sft_record
from taoTrain.utils import set_seed, get_device

# Label value this script treats as "ignored / not trained on".
IGNORE_INDEX = -100

# Number of individual samples printed in detail for each batch.
SAMPLES_SHOWN_PER_BATCH = 3

SpecialTokenIds = Dict[str, Optional[int]]


def _fail(message: str, show_traceback: bool = False) -> NoReturn:
    """Print ``message`` (optionally followed by the active traceback) and exit with status 1."""
    print(message)
    if show_traceback:
        traceback.print_exc()
    sys.exit(1)


def print_separator(title: str = "", char: str = "=", width: int = 80):
    """
    Print a separator line, optionally with an embedded title.

    Args:
        title: Text placed after the first ``char``; omitted when empty.
        char: Character used to draw the line.
        width: Target line width. A title longer than ``width - 3`` simply
            gets no trailing fill.
    """
    if title:
        print(f"\n{char} {title} {char * (width - len(title) - 3)}")
    else:
        print(f"\n{char * width}")


def format_tensor_info(tensor: torch.Tensor) -> str:
    """Return a one-line ``shape=..., dtype=..., device=...`` summary of ``tensor``."""
    return f"shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}"


def _format_preview(values: torch.Tensor, max_display: int) -> str:
    """Return the first ``max_display`` values as a list string, with "..." appended if truncated."""
    suffix = "..." if len(values) > max_display else ""
    return f"{values[:max_display].tolist()}{suffix}"


def _format_counts(counts: Dict[str, int]) -> str:
    """Render ``{"bos": 1, "eos": 2}`` as ``"bos=1, eos=2"``."""
    return ", ".join(f"{name}={count}" for name, count in counts.items())


def _read_token_id(tokenizer, attr: str) -> Optional[int]:
    """
    Best-effort read of a special-token ID from ``tokenizer.<attr>``.

    ``attr`` may name a method (called with no arguments) or a plain attribute.
    Returns None when the attribute is missing or None, when reading/calling it
    raises, or when the value cannot be converted to ``int``.
    """
    try:
        value = getattr(tokenizer, attr, None)
        if callable(value):
            value = value()
        return None if value is None else int(value)
    except Exception:
        # Diagnostic script: an unreadable ID is reported as unknown (None)
        # instead of aborting the whole inspection.
        return None


def get_special_token_ids(tokenizer) -> SpecialTokenIds:
    """
    Extract special token IDs from tokenizer.

    For each of bos/eos/pad/unk two access conventions are probed:
      - ``<name>_token_id`` attribute (e.g. HuggingFace tokenizers)
      - ``<name>_id()`` method        (e.g. SentencePieceTokenizerWrapper)
    The ``<name>_token_id`` value is preferred when it is not None. Each lookup
    is independent, so one failing accessor does not affect the others.

    Args:
        tokenizer: Tokenizer object, or None.

    Returns:
        Dict with keys bos, eos, pad, unk mapping to an int ID, or None when
        the ID is unavailable.
    """
    special_tokens: SpecialTokenIds = {"bos": None, "eos": None, "pad": None, "unk": None}
    if tokenizer is None:
        return special_tokens

    for name in special_tokens:
        hf_id = _read_token_id(tokenizer, f"{name}_token_id")
        special_tokens[name] = hf_id if hf_id is not None else _read_token_id(tokenizer, f"{name}_id")
    return special_tokens


def count_special_tokens(token_ids: torch.Tensor, special_token_ids: SpecialTokenIds) -> Dict[str, int]:
    """
    Count occurrences of special tokens in token IDs.

    Args:
        token_ids: Tensor of any shape (every element is counted) or a
            sequence of ints.
        special_token_ids: Dict mapping special token names to IDs (None = unknown).

    Returns:
        Dict with the same keys as ``special_token_ids``; unknown IDs count as 0.
    """
    ids = torch.as_tensor(token_ids)
    return {
        name: 0 if token_id is None else int((ids == token_id).sum().item())
        for name, token_id in special_token_ids.items()
    }


def print_sample(sample_idx: int, sample: Dict[str, torch.Tensor], tokenizer=None,
                 special_token_ids: Optional[SpecialTokenIds] = None, max_display: int = 20):
    """
    Print a single sample with formatted information.

    Each section is printed only if its key is present in ``sample``:
      - input_ids: tensor info, value preview, special-token counts (when
        ``special_token_ids`` is given) and a decoded preview (when
        ``tokenizer`` is given)
      - attention_mask: tensor info and count of non-padding positions
      - labels: tensor info, value preview and how many are not IGNORE_INDEX

    Args:
        sample_idx: Index shown in the header line only.
        sample: Field name -> this sample's tensor (expected 1D, one row of a batch).
        tokenizer: Optional object with a ``decode(ids, skip_special_tokens=...)`` method.
        special_token_ids: Optional output of ``get_special_token_ids``.
        max_display: Maximum number of values/tokens to preview.
    """
    print(f"\n Sample {sample_idx}:")

    if "input_ids" in sample:
        input_ids = sample["input_ids"]
        print(f" input_ids: {format_tensor_info(input_ids)}")
        print(f" Values: {_format_preview(input_ids, max_display)}")

        if special_token_ids is not None:
            special_counts = count_special_tokens(input_ids, special_token_ids)
            print(f" Special tokens: {_format_counts(special_counts)}")

        if tokenizer is not None:
            try:
                decoded = tokenizer.decode(input_ids[:max_display].tolist(), skip_special_tokens=False)
                preview = decoded.replace('\n', '\\n')
                ellipsis = "..." if len(input_ids) > max_display else ""
                print(f" Decoded: {preview}{ellipsis}")
            except Exception as e:
                print(f" [Decode error: {e}]")

    if "attention_mask" in sample:
        attention_mask = sample["attention_mask"]
        print(f" attention_mask: {format_tensor_info(attention_mask)}")
        non_pad_count = attention_mask.sum().item()
        print(f" Non-padding tokens: {non_pad_count}/{len(attention_mask)}")

    # Labels are optional (e.g. some DataLoader batches): only shown when present.
    if "labels" in sample:
        labels = sample["labels"]
        print(f" labels: {format_tensor_info(labels)}")
        print(f" Values: {_format_preview(labels, max_display)}")
        valid_count = int((labels != IGNORE_INDEX).sum().item())
        print(f" Valid labels (not -100): {valid_count}/{len(labels)}")


def _decode_token(tokenizer, token_id: int) -> str:
    """Decode one token for display (newlines escaped, at most 15 chars); "?" without a tokenizer."""
    if tokenizer is None:
        return "?"
    try:
        return tokenizer.decode([token_id], skip_special_tokens=False).replace('\n', '\\n')[:15]
    except Exception:
        return f"ID:{token_id}"


def print_sample_sft(sample_idx: int, sample: Dict[str, torch.Tensor], mask: Optional[list] = None,
                     tokenizer=None, special_token_ids: Optional[SpecialTokenIds] = None,
                     max_display: int = 20):
    """
    Print a single SFT sample with detailed masking information.

    Shows:
    - Which tokens are user input (mask=0, labeled -100)
    - Which tokens are assistant output (mask=1, labeled with token id)
    - Token decoding and special token information

    When ``mask`` is None, a label-based breakdown (ignored vs. trained
    positions) is printed instead, provided ``sample`` has labels.

    Args:
        sample_idx: Index shown in the header line only.
        sample: Field name -> this sample's tensor (expected 1D).
        mask: Optional per-token list of 0/1 flags (0 = user, 1 = assistant);
            positions beyond its length are treated as 0.
        tokenizer: Optional tokenizer used to decode individual tokens.
        special_token_ids: Optional output of ``get_special_token_ids``.
        max_display: Maximum number of tokens listed in the region table.
    """
    print(f"\n SFT Sample {sample_idx}:")
    if "input_ids" not in sample:
        return

    input_ids = sample["input_ids"]
    print(f" input_ids: {format_tensor_info(input_ids)}")

    labels = sample.get("labels")
    if labels is not None:
        print(f" labels: {format_tensor_info(labels)}")
    else:
        print(" labels: <missing>")

    if special_token_ids is not None:
        special_counts = count_special_tokens(input_ids, special_token_ids)
        print(f" Special tokens: {_format_counts(special_counts)}")

    if mask is not None:
        user_count = sum(1 for m in mask if m == 0)
        assistant_count = sum(1 for m in mask if m == 1)
        print(" Mask breakdown:")
        print(f" - User input tokens (mask=0, ignored): {user_count}")
        print(f" - Assistant output tokens (mask=1, trained): {assistant_count}")

        print(f" Token regions (first {max_display} tokens):")
        for i in range(min(max_display, len(input_ids))):
            token_id = input_ids[i].item()
            mask_val = mask[i] if i < len(mask) else 0
            region_type = "USER " if mask_val == 0 else "ASST "
            # Guard against missing or shorter label tensors instead of raising.
            if labels is not None and i < len(labels):
                label = labels[i].item()
                label_str = "IGNORE" if label == IGNORE_INDEX else f"{label:5d}"
            else:
                label_str = "  N/A"
            token_str = _decode_token(tokenizer, token_id)
            print(f" [{i:3d}] {region_type} | label={label_str} | token={token_str}")
        if len(input_ids) > max_display:
            print(f" ... and {len(input_ids) - max_display} more tokens")
    elif labels is not None:
        # Fallback: no mask available, derive the breakdown from labels alone.
        masked_count = (labels == IGNORE_INDEX).sum().item()
        valid_count = (labels != IGNORE_INDEX).sum().item()
        print(" Token labels breakdown:")
        print(f" - Masked tokens (label=-100): {masked_count}")
        print(f" - Training tokens: {valid_count}")


def _infer_training_mode(config_path: Path) -> TrainingModeEnum:
    """
    Guess the training mode from the config path (default: PRETRAIN).

    The filename is checked before the full path, so a directory name cannot
    override a telling filename (``runs/pretrain_x/sft.yaml`` -> SFT). "rl" is
    only matched at the start of a word (path split on non-alphanumerics),
    because as a raw substring it hits unrelated names such as "world" or "url".
    """
    for text in (config_path.name.lower(), str(config_path).lower()):
        words = re.split(r"[^a-z0-9]+", text)
        if "pretrain" in text:
            return TrainingModeEnum.PRETRAIN
        if "sft" in text:
            return TrainingModeEnum.SFT
        if "dpo" in text or any(word.startswith("rl") for word in words):
            return TrainingModeEnum.RL
    return TrainingModeEnum.PRETRAIN


def _print_config_summary(config, show_sft: bool) -> None:
    """Print the configuration fields relevant to data loading (end of Step 1)."""
    print("\nConfiguration Summary:")
    print(f" - Mode: {config.mode.value}")
    print(f" - Model: {config.model.architecture_type.value}")
    print(f" - Vocab size: {config.model.vocab_size}")
    print(f" - Max seq length: {config.model.max_seq_length}")
    print(f" - Batch size: {config.batch_size}")
    print(f" - Dataset source: {'Local JSONL' if config.dataset.local else 'HuggingFace'}")
    if config.dataset.local:
        print(f" - JSONL path: {config.dataset.jsonl_path}")
        print(f" - Tokenizer: {config.dataset.tokenizer_type or 'auto-detect'}")
        if config.dataset.tokenizer_path:
            print(f" - Tokenizer path: {config.dataset.tokenizer_path}")
        print(f" - Samples per chunk: {config.dataset.samples_per_chunk}")
        print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
    else:
        print(f" - Dataset name: {config.dataset.dataset_name}")
        print(f" - Dataset config: {config.dataset.config or 'default'}")
        print(f" - Split: {config.dataset.split}")
    if config.dataset.max_samples:
        print(f" - Max samples (limit): {config.dataset.max_samples}")

    # SFT settings are probed with hasattr/getattr because other config types
    # may not define them.
    if show_sft and hasattr(config, 'response_loss_only'):
        print("\n SFT Configuration:")
        print(f" - Response loss only: {config.response_loss_only}")
        user_token = getattr(config, 'user_token', '')
        assistant_token = getattr(config, 'assistant_token', '')
        print(f" - User token: {user_token}")
        print(f" - Assistant token: {assistant_token}")
        checkpoint_path = getattr(config, 'checkpoint_path', None)
        if checkpoint_path:
            print(f" - Checkpoint path: {checkpoint_path}")


def _setup_jsonl_pipeline(train_dataset, config, device) -> Tuple[Any, Any, Optional[SpecialTokenIds]]:
    """
    Step 4 for local JSONL datasets: build a TokenizationQueue feeding an AsyncBatchIterator.

    Returns:
        (loader, tokenizer, special_token_ids). Exits the process on failure.
    """
    print_separator("STEP 4: SETUP ASYNC PIPELINE")

    chunk_manager = train_dataset.chunk_manager
    tokenizer = train_dataset.tokenizer

    special_token_ids = get_special_token_ids(tokenizer)
    print("✓ Special token IDs extracted:")
    for token_name, token_id in special_token_ids.items():
        print(f" - {token_name.upper()}: {token_id}")

    print("✓ ChunkManager found:")
    print(f" - Total chunks: {chunk_manager.num_chunks}")
    print(f" - Effective lines: {chunk_manager.effective_lines}")
    print(f" - File size: {chunk_manager.file_size_bytes / (1024**2):.1f} MB")
    line_ranges = chunk_manager.chunk_line_ranges
    if len(line_ranges) > 3:
        print(f" - Chunk line ranges: {line_ranges[:3]}...")
    else:
        print(f" - Chunk line ranges: {line_ranges}")
    print(f"✓ Tokenizer found: {type(tokenizer).__name__}")

    max_queue_size = 4
    print("✓ Creating TokenizationQueue...")
    try:
        tokenization_queue = TokenizationQueue(
            chunk_manager=chunk_manager,
            tokenizer=tokenizer,
            config=config,
            max_queue_size=max_queue_size,
            shuffle_chunks=True,
            num_threads=config.dataset.tokenizer_threads,
        )
        print(f" - Max queue size: {max_queue_size}")
        print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
        print(f" - Total items: {tokenization_queue.total_items}")
    except Exception as e:
        _fail(f"✗ Failed to create TokenizationQueue: {e}", show_traceback=True)

    print("✓ Creating AsyncBatchIterator...")
    try:
        loader = AsyncBatchIterator(
            tokenization_queue=tokenization_queue,
            batch_size=config.batch_size,
            device=device,
            drop_last=True,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
        )
        print(f" - Batch size: {config.batch_size}")
        print(f" - Device: {device}")
        print(f" - Gradient accumulation: {config.gradient_accumulation_steps}")
    except Exception as e:
        _fail(f"✗ Failed to create AsyncBatchIterator: {e}", show_traceback=True)

    return loader, tokenizer, special_token_ids


def _setup_standard_loader(train_dataset, config) -> Tuple[Any, Any, Optional[SpecialTokenIds]]:
    """
    Step 4 for non-JSONL datasets: build a standard DataLoader via ``get_dataloader``.

    special_token_ids is None when the dataset exposes no tokenizer, so no
    (meaningless) all-zero special-token counts are reported later.

    Returns:
        (loader, tokenizer, special_token_ids). Exits the process on failure.
    """
    print_separator("STEP 4: USING STANDARD DATALOADER")
    from taoTrain.data.loaders import get_dataloader

    tokenizer = getattr(train_dataset, "tokenizer", None)
    special_token_ids = None
    if tokenizer is not None:
        special_token_ids = get_special_token_ids(tokenizer)
        print("✓ Special token IDs extracted:")
        for token_name, token_id in special_token_ids.items():
            print(f" - {token_name.upper()}: {token_id}")

    print("✓ HuggingFace dataset detected")
    print(f"✓ Will use standard DataLoader (batch_size={config.batch_size})")
    try:
        loader = get_dataloader(
            train_dataset,
            config,
            shuffle=False,
            drop_last=False,
        )
        print("✓ DataLoader created")
    except Exception as e:
        _fail(f"✗ Failed to create DataLoader: {e}", show_traceback=True)

    return loader, tokenizer, special_token_ids


def _lookup_sft_mask(dataset, index: int) -> Optional[list]:
    """
    Best-effort lookup of an SFT loss mask from ``dataset._current_chunk_data["mask"][index]``.

    Returns None when the dataset has no current chunk data, no "mask" entry,
    or ``index`` is out of range.
    NOTE(review): batches come from the async pipeline with shuffle_chunks=True,
    so this index may not correspond to the displayed sample — treat the mask
    as indicative only.
    """
    chunk_data = getattr(dataset, "_current_chunk_data", None)
    if chunk_data is None or "mask" not in chunk_data:
        return None
    masks = chunk_data["mask"]
    return masks[index] if index < len(masks) else None


def inspect_dataset(config_path: str, num_samples: int = 10, max_display: int = 20, show_sft: bool = False):
    """
    Main inspection function.

    Steps: load config -> set device/seed -> create datasets -> build a data
    loader (async JSONL pipeline or standard DataLoader) -> fetch and print
    the first ``num_samples`` batches -> print summary statistics.
    Exits the process with status 1 on any setup or fetch failure.

    Args:
        config_path: Path to YAML config file
        num_samples: Number of batches to fetch and display
        max_display: Max tokens to display in preview
        show_sft: If True, show SFT-specific masking information (requires SFT dataset)
    """
    # ========================================================================
    # Step 1: Load and validate config
    # ========================================================================
    print_separator("STEP 1: LOAD CONFIGURATION")
    config_path = Path(config_path)
    if not config_path.exists():
        _fail(f"✗ Config file not found: {config_path}")
    print(f"✓ Loading config from: {config_path}")

    try:
        config = load_config(config_path, _infer_training_mode(config_path))
        print(f"✓ Config loaded (mode: {config.mode.value})")
    except Exception as e:
        _fail(f"✗ Failed to load config: {e}")

    _print_config_summary(config, show_sft)

    # ========================================================================
    # Step 2: Setup device and random seed
    # ========================================================================
    print_separator("STEP 2: SETUP DEVICE AND SEED")
    device = get_device(config.device)
    print(f"✓ Device: {device}")
    print(f"✓ CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f" - GPU: {torch.cuda.get_device_name(0)}")
        print(f" - CUDA version: {torch.version.cuda}")
    set_seed(config.seed)
    print(f"✓ Random seed set: {config.seed}")

    # ========================================================================
    # Step 3: Create datasets
    # ========================================================================
    print_separator("STEP 3: CREATE DATASETS")
    try:
        train_dataset, val_dataset = create_datasets(config)
        print(f"✓ Train dataset created: {type(train_dataset).__name__}")
        # Explicit None check: an empty (len 0) dataset is still a dataset.
        if val_dataset is not None:
            print(f"✓ Validation dataset created: {type(val_dataset).__name__}")
        else:
            print("✓ No validation dataset (JSONL or not configured)")
    except Exception as e:
        _fail(f"✗ Failed to create datasets: {e}", show_traceback=True)

    try:
        train_length = len(train_dataset)
    except TypeError:
        # Iterable-style datasets may not implement __len__.
        train_length = "unknown"
    print(f" - Train dataset length: {train_length}")

    # ========================================================================
    # Step 4: Setup data loader (async pipeline for JSONL datasets)
    # ========================================================================
    if isinstance(train_dataset, BaseJSONLDataset):
        loader, tokenizer, special_token_ids = _setup_jsonl_pipeline(train_dataset, config, device)
    else:
        loader, tokenizer, special_token_ids = _setup_standard_loader(train_dataset, config)

    # ========================================================================
    # Step 5: Fetch and display samples
    # ========================================================================
    print_separator(f"STEP 5: FETCH FIRST {num_samples} BATCHES")
    if loader is None:
        _fail("✗ No data loader available")

    num_batches = max(num_samples, 0)
    use_sft_view = show_sft and isinstance(train_dataset, SFTJSONLDataset)
    batches_fetched = 0
    total_samples = 0
    total_tokens = 0
    cumulative_special_counts = {name: 0 for name in (special_token_ids or {})}

    try:
        # islice stops consuming the loader after num_batches batches, so no
        # extra batch is pulled from the (possibly async) pipeline.
        batches = islice(loader, num_batches)
        for batch_idx, batch in enumerate(tqdm(batches, total=num_batches, desc="Fetching batches")):
            batches_fetched += 1
            print_separator(f"BATCH {batch_idx}", char="-", width=60)
            print(f"Batch keys: {batch.keys()}")
            for key, tensor in batch.items():
                print(f" {key}: {format_tensor_info(tensor)}")

            if "input_ids" in batch:
                batch_token_count = batch["input_ids"].numel()
                total_tokens += batch_token_count
                print(f" Total tokens in batch: {batch_token_count}")
                if special_token_ids is not None:
                    # count_special_tokens counts over all elements, so no flattening is needed.
                    batch_special_counts = count_special_tokens(batch["input_ids"], special_token_ids)
                    print(f" Batch special tokens: {_format_counts(batch_special_counts)}")
                    for token_name, count in batch_special_counts.items():
                        cumulative_special_counts[token_name] += count

            batch_size = next(iter(batch.values())).shape[0]
            total_samples += batch_size
            shown = min(SAMPLES_SHOWN_PER_BATCH, batch_size)
            print(f"\nSamples in batch (showing first {shown} of {batch_size}):")
            for sample_idx in range(shown):
                sample = {key: tensor[sample_idx] for key, tensor in batch.items()}
                global_idx = batch_idx * batch_size + sample_idx
                if use_sft_view:
                    mask = None
                    try:
                        # Bound check and lookup both use the same global index.
                        mask = _lookup_sft_mask(train_dataset, global_idx)
                    except Exception as e:
                        print(f" [Could not access mask: {e}]")
                    print_sample_sft(global_idx, sample, mask, tokenizer, special_token_ids, max_display)
                else:
                    print_sample(global_idx, sample, tokenizer, special_token_ids, max_display)
            if batch_size > SAMPLES_SHOWN_PER_BATCH:
                print(f" ... and {batch_size - SAMPLES_SHOWN_PER_BATCH} more samples")
    except Exception as e:
        _fail(f"✗ Error during batch fetching: {e}", show_traceback=True)

    # ========================================================================
    # Summary Statistics
    # ========================================================================
    print_separator("SUMMARY STATISTICS")
    print(f"Total batches fetched: {batches_fetched}")
    print(f"Total samples: {total_samples}")
    print(f"Total tokens: {total_tokens}")
    if total_samples > 0:
        print(f"Average tokens per sample: {total_tokens / total_samples:.1f}")

    if special_token_ids is not None:
        print("\nSpecial token counts across all batches:")
        for token_name, total_count in cumulative_special_counts.items():
            pct = (total_count / total_tokens * 100) if total_tokens > 0 else 0
            print(f" - {token_name.upper()}: {total_count} tokens ({pct:.2f}%)")

    print_separator()
    print("✓ Inspection complete!")


def main():
    """Parse command-line arguments and run ``inspect_dataset``."""
    parser = argparse.ArgumentParser(
        description="Test dataset loading pipeline - inspect data shapes, tokenization, and samples"
    )
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to YAML/JSON config file (e.g., config/pretrain.yaml)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of batches to display (default: 10)"
    )
    parser.add_argument(
        "--max-display",
        type=int,
        default=20,
        help="Maximum tokens to display in preview (default: 20)"
    )
    parser.add_argument(
        "--sft",
        action="store_true",
        help="If set, show SFT-specific masking information (user vs assistant tokens)"
    )
    args = parser.parse_args()

    inspect_dataset(
        args.config_path,
        num_samples=args.num_samples,
        max_display=args.max_display,
        show_sft=args.sft
    )


if __name__ == "__main__":
    main()