TaoNet-mini-T2 / code /TaoTrain /tests /test_dataset.py
StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
"""
Test script to inspect dataset loading pipeline.
Usage:
python test_dataset.py config/pretrain.yaml
python test_dataset.py config/sft.yaml --num-samples 10
python test_dataset.py config/sft.yaml --sft --num-samples 5 # Show SFT-specific info
"""
import sys
import argparse
from pathlib import Path
from typing import Optional, Dict
import torch
from tqdm import tqdm
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from taoTrain.config import load_config, TrainingModeEnum, PretrainConfig
from taoTrain.core import create_datasets
from taoTrain.data import TokenizationQueue, AsyncBatchIterator
from taoTrain.data import BaseJSONLDataset, SFTJSONLDataset, parse_sft_record
from taoTrain.utils import set_seed, get_device
def print_separator(title: str = "", char: str = "=", width: int = 80):
    """
    Print a blank line followed by a horizontal rule, optionally labelled.

    With a title the rule reads ``<char> <title> <char...>``, padded so the whole
    line is ``width`` characters long when ``char`` is a single character (a
    title longer than the width is never truncated; the padding just vanishes).
    Without a title the rule is ``width`` repetitions of ``char``.
    """
    if not title:
        rule = char * width
    else:
        padding = char * (width - len(title) - 3)
        rule = f"{char} {title} {padding}"
    print("\n" + rule)
def format_tensor_info(tensor: torch.Tensor) -> str:
    """Summarise a tensor's shape, dtype and device as a single display string."""
    dims = tuple(tensor.shape)
    return "shape={}, dtype={}, device={}".format(dims, tensor.dtype, tensor.device)
def get_special_token_ids(tokenizer) -> Dict[str, Optional[int]]:
    """
    Extract special token IDs from a tokenizer.

    Two accessor styles are probed for each token:
      - SentencePiece-style methods: ``bos_id()``, ``eos_id()``, ``pad_id()``, ``unk_id()``
      - HuggingFace-style attributes: ``bos_token_id``, ``eos_token_id``,
        ``pad_token_id``, ``unk_token_id``

    When both styles yield a value, the HuggingFace-style value wins; a
    HuggingFace-style ``None`` no longer erases a valid SentencePiece value.
    Each token is looked up independently, so one failing accessor does not
    discard the others. Negative IDs (SentencePiece reports -1 for a disabled
    special token) are reported as None.

    Args:
        tokenizer: Tokenizer object, or None.

    Returns:
        Dict with keys ``bos``, ``eos``, ``pad``, ``unk``, each mapped to an
        int ID or None when the tokenizer does not define that token.
    """
    special_tokens: Dict[str, Optional[int]] = {
        "bos": None,
        "eos": None,
        "pad": None,
        "unk": None,
    }
    if tokenizer is None:
        return special_tokens

    def _lookup(attr_name: str) -> Optional[int]:
        # The attribute may be a method (SentencePiece) or a plain value (HF).
        try:
            value = getattr(tokenizer, attr_name, None)
            if callable(value):
                value = value()
        except Exception:
            # Best-effort: a broken accessor only loses this one token.
            return None
        if value is None:
            return None
        try:
            value = int(value)
        except (TypeError, ValueError):
            return None
        return value if value >= 0 else None

    for name in special_tokens:
        sp_value = _lookup(f"{name}_id")
        hf_value = _lookup(f"{name}_token_id")
        special_tokens[name] = hf_value if hf_value is not None else sp_value
    return special_tokens
def count_special_tokens(token_ids: torch.Tensor, special_token_ids: Dict[str, Optional[int]]) -> Dict[str, int]:
    """
    Count occurrences of special tokens in token IDs.

    Args:
        token_ids: Tensor of token IDs (any shape; all elements are counted).
            Lists and other array-likes accepted by ``torch.as_tensor`` also
            work; previously a plain list crashed with AttributeError.
        special_token_ids: Dict mapping special token names to IDs (None = undefined).

    Returns:
        Dict with the count of each special token; undefined tokens count 0.
    """
    # Count in torch directly: no CPU/numpy round-trip, and non-tensor inputs
    # are converted once up front.
    ids = token_ids if isinstance(token_ids, torch.Tensor) else torch.as_tensor(token_ids)
    counts: Dict[str, int] = {}
    for token_name, token_id in special_token_ids.items():
        if token_id is None:
            counts[token_name] = 0
        else:
            counts[token_name] = int((ids == token_id).sum().item())
    return counts
def print_sample(sample_idx: int, sample: Dict[str, torch.Tensor], tokenizer=None,
                 special_token_ids: Optional[Dict[str, Optional[int]]] = None, max_display: int = 20):
    """
    Print a single (un-batched) sample with formatted information.

    Every field is optional. ``labels`` is no longer read unconditionally, so a
    sample without labels does not raise KeyError, and labels are reported once
    instead of twice.

    Args:
        sample_idx: Index shown in the header line.
        sample: Mapping of field name to tensor; recognised keys are
            ``input_ids``, ``attention_mask`` and ``labels``.
        tokenizer: Optional tokenizer exposing ``decode(ids, skip_special_tokens=...)``,
            used for a decoded preview of ``input_ids``.
        special_token_ids: Optional mapping of special-token name to ID; when
            given, occurrences in ``input_ids`` are counted.
        max_display: Number of leading tokens to show in previews.
    """
    print(f"\n Sample {sample_idx}:")
    input_ids = sample.get("input_ids")
    attention_mask = sample.get("attention_mask")
    labels = sample.get("labels")

    # Input IDs
    if input_ids is not None:
        truncated = len(input_ids) > max_display
        print(f" input_ids: {format_tensor_info(input_ids)}")
        print(f" Values: {input_ids[:max_display].tolist()}{'...' if truncated else ''}")
        # Count special tokens
        if special_token_ids is not None:
            special_counts = count_special_tokens(input_ids, special_token_ids)
            special_summary = ", ".join(f"{name}={count}" for name, count in special_counts.items())
            print(f" Special tokens: {special_summary}")
        # Try to decode if tokenizer available
        if tokenizer is not None:
            try:
                decoded = tokenizer.decode(input_ids[:max_display].tolist(), skip_special_tokens=False)
                preview = decoded.replace('\n', '\\n')
                # Only mark the preview as truncated when it actually is.
                print(f" Decoded: {preview}{'...' if truncated else ''}")
            except Exception as e:
                print(f" [Decode error: {e}]")

    # Attention mask
    if attention_mask is not None:
        print(f" attention_mask: {format_tensor_info(attention_mask)}")
        non_pad_count = attention_mask.sum().item()
        print(f" Non-padding tokens: {non_pad_count}/{len(attention_mask)}")

    # Labels (-100 marks positions excluded from the loss)
    if labels is not None:
        print(f" labels: {format_tensor_info(labels)}")
        print(f" Values: {labels[:max_display].tolist()}{'...' if len(labels) > max_display else ''}")
        valid_labels = labels[labels != -100]
        print(f" Valid labels (not -100): {len(valid_labels)}/{len(labels)}")
def print_sample_sft(sample_idx: int, sample: Dict[str, torch.Tensor], mask: Optional[list] = None,
                     tokenizer=None, special_token_ids: Optional[Dict[str, Optional[int]]] = None,
                     max_display: int = 20):
    """
    Print a single SFT sample with detailed masking information.

    Shows:
    - Which tokens are user input (mask=0, labeled -100)
    - Which tokens are assistant output (mask=1, labeled with token id)
    - Token decoding and special token information

    Args:
        sample_idx: Index shown in the header line.
        sample: Mapping with ``input_ids`` and ``labels`` tensors. If
            ``input_ids`` is missing nothing beyond the header is printed; if
            ``labels`` is missing the label breakdown is skipped (it used to
            raise KeyError).
        mask: Optional per-token sequence of 0 (user) / 1 (assistant). Positions
            beyond its length are displayed as user. When None, a label-based
            breakdown is printed instead.
        tokenizer: Optional tokenizer exposing ``decode(ids, skip_special_tokens=...)``.
        special_token_ids: Optional mapping of special-token name to ID; when
            given, occurrences in ``input_ids`` are counted.
        max_display: Number of leading tokens to show in the per-token table.
    """
    print(f"\n SFT Sample {sample_idx}:")
    if "input_ids" not in sample:
        return
    input_ids = sample["input_ids"]
    print(f" input_ids: {format_tensor_info(input_ids)}")

    if special_token_ids is not None:
        special_counts = count_special_tokens(input_ids, special_token_ids)
        special_summary = ", ".join(f"{name}={count}" for name, count in special_counts.items())
        print(f" Special tokens: {special_summary}")

    labels = sample.get("labels")
    if labels is None:
        return
    print(f" labels: {format_tensor_info(labels)}")

    if mask is None:
        # Fallback: just show labels
        masked_count = (labels == -100).sum().item()
        valid_count = (labels != -100).sum().item()
        print(" Token labels breakdown:")
        print(f" - Masked tokens (label=-100): {masked_count}")
        print(f" - Training tokens: {valid_count}")
        return

    # Show mask breakdown
    user_count = sum(1 for m in mask if m == 0)
    assistant_count = sum(1 for m in mask if m == 1)
    print(" Mask breakdown:")
    print(f" - User input tokens (mask=0, ignored): {user_count}")
    print(f" - Assistant output tokens (mask=1, trained): {assistant_count}")

    # Show which regions are which
    print(f" Token regions (first {max_display} tokens):")
    for i in range(min(max_display, len(input_ids))):
        token_id = input_ids[i].item()
        label = labels[i].item()
        mask_val = mask[i] if i < len(mask) else 0
        region_type = "USER " if mask_val == 0 else "ASST "
        label_str = "IGNORE" if label == -100 else f"{label:5d}"
        token_str = "?"
        if tokenizer is not None:
            try:
                token_str = tokenizer.decode([token_id], skip_special_tokens=False).replace('\n', '\\n')[:15]
            except Exception:
                # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit propagate.
                token_str = f"ID:{token_id}"
        print(f" [{i:3d}] {region_type} | label={label_str} | token={token_str}")
    if len(input_ids) > max_display:
        print(f" ... and {len(input_ids) - max_display} more tokens")
def _infer_training_mode(config_path: Path):
    """
    Guess the training mode from the config file *name* (parent directories are ignored).

    Only the file name is inspected, so a path like ``/home/carl/configs/base.yaml``
    is no longer classified as RL just because a directory contains "rl".
    This is still a substring heuristic on the name; falls back to PRETRAIN.
    """
    name = config_path.name.lower()
    if "pretrain" in name:
        return TrainingModeEnum.PRETRAIN
    if "sft" in name:
        return TrainingModeEnum.SFT
    if "rl" in name or "dpo" in name:
        return TrainingModeEnum.RL
    return TrainingModeEnum.PRETRAIN  # Default


def inspect_dataset(config_path: str, num_samples: int = 10, max_display: int = 20, show_sft: bool = False):
    """
    Main inspection function.

    Loads the config, builds the training dataset and a loader for it, prints
    the first ``num_samples`` batches (up to 3 samples each) and finishes with
    summary statistics. Calls ``sys.exit(1)`` on any setup or fetch failure.

    Args:
        config_path: Path to YAML config file; the training mode is inferred
            from its file name (see ``_infer_training_mode``).
        num_samples: Number of batches to fetch and display; values <= 0 fetch nothing.
        max_display: Max tokens to display in preview
        show_sft: If True, show SFT-specific masking information (requires SFT dataset)
    """
    import itertools
    import traceback

    # ========================================================================
    # Step 1: Load and validate config
    # ========================================================================
    print_separator("STEP 1: LOAD CONFIGURATION")
    config_path = Path(config_path)
    if not config_path.exists():
        print(f"✗ Config file not found: {config_path}")
        sys.exit(1)
    print(f"✓ Loading config from: {config_path}")
    try:
        mode = _infer_training_mode(config_path)
        config = load_config(config_path, mode)
        print(f"✓ Config loaded (mode: {config.mode.value})")
    except Exception as e:
        print(f"✗ Failed to load config: {e}")
        sys.exit(1)

    # Print config summary
    print("\nConfiguration Summary:")
    print(f" - Mode: {config.mode.value}")
    print(f" - Model: {config.model.architecture_type.value}")
    print(f" - Vocab size: {config.model.vocab_size}")
    print(f" - Max seq length: {config.model.max_seq_length}")
    print(f" - Batch size: {config.batch_size}")
    print(f" - Dataset source: {'Local JSONL' if config.dataset.local else 'HuggingFace'}")
    if config.dataset.local:
        print(f" - JSONL path: {config.dataset.jsonl_path}")
        print(f" - Tokenizer: {config.dataset.tokenizer_type or 'auto-detect'}")
        if config.dataset.tokenizer_path:
            print(f" - Tokenizer path: {config.dataset.tokenizer_path}")
        print(f" - Samples per chunk: {config.dataset.samples_per_chunk}")
        print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
    else:
        print(f" - Dataset name: {config.dataset.dataset_name}")
        print(f" - Dataset config: {config.dataset.config or 'default'}")
        print(f" - Split: {config.dataset.split}")
    if config.dataset.max_samples:
        print(f" - Max samples (limit): {config.dataset.max_samples}")

    # Print SFT-specific config if applicable
    if show_sft and hasattr(config, 'response_loss_only'):
        print("\n SFT Configuration:")
        print(f" - Response loss only: {config.response_loss_only}")
        user_token = getattr(config, 'user_token', '<user>')
        assistant_token = getattr(config, 'assistant_token', '<assistant>')
        print(f" - User token: {user_token}")
        print(f" - Assistant token: {assistant_token}")
        checkpoint_path = getattr(config, 'checkpoint_path', None)
        if checkpoint_path:
            print(f" - Checkpoint path: {checkpoint_path}")

    # ========================================================================
    # Step 2: Setup device and random seed
    # ========================================================================
    print_separator("STEP 2: SETUP DEVICE AND SEED")
    device = get_device(config.device)
    print(f"✓ Device: {device}")
    print(f"✓ CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f" - GPU: {torch.cuda.get_device_name(0)}")
        print(f" - CUDA version: {torch.version.cuda}")
    set_seed(config.seed)
    print(f"✓ Random seed set: {config.seed}")

    # ========================================================================
    # Step 3: Create datasets
    # ========================================================================
    print_separator("STEP 3: CREATE DATASETS")
    try:
        train_dataset, val_dataset = create_datasets(config)
        print(f"✓ Train dataset created: {type(train_dataset).__name__}")
        # Compare with None: truthiness would go through __len__ when defined.
        if val_dataset is not None:
            print(f"✓ Validation dataset created: {type(val_dataset).__name__}")
        else:
            print("✓ No validation dataset (JSONL or not configured)")
    except Exception as e:
        print(f"✗ Failed to create datasets: {e}")
        traceback.print_exc()
        sys.exit(1)
    try:
        train_length = len(train_dataset)
    except TypeError:
        # Iterable-style datasets may not define __len__.
        train_length = "unknown"
    print(f" - Train dataset length: {train_length}")

    # ========================================================================
    # Step 4: Setup async pipeline (for JSONL datasets)
    # ========================================================================
    async_loader = None
    tokenizer = None
    special_token_ids = None
    if isinstance(train_dataset, BaseJSONLDataset):
        print_separator("STEP 4: SETUP ASYNC PIPELINE")
        # Extract components
        chunk_manager = train_dataset.chunk_manager
        tokenizer = train_dataset.tokenizer
        # Get special token IDs
        special_token_ids = get_special_token_ids(tokenizer)
        print("✓ Special token IDs extracted:")
        for token_name, token_id in special_token_ids.items():
            print(f" - {token_name.upper()}: {token_id}")
        print("✓ ChunkManager found:")
        print(f" - Total chunks: {chunk_manager.num_chunks}")
        print(f" - Effective lines: {chunk_manager.effective_lines}")
        print(f" - File size: {chunk_manager.file_size_bytes / (1024**2):.1f} MB")
        line_ranges = chunk_manager.chunk_line_ranges
        if len(line_ranges) > 3:
            print(f" - Chunk line ranges: {line_ranges[:3]}...")
        else:
            print(f" - Chunk line ranges: {line_ranges}")
        print(f"✓ Tokenizer found: {type(tokenizer).__name__}")
        # Create tokenization queue
        print("✓ Creating TokenizationQueue...")
        try:
            tokenization_queue = TokenizationQueue(
                chunk_manager=chunk_manager,
                tokenizer=tokenizer,
                config=config,
                max_queue_size=4,
                shuffle_chunks=True,
                num_threads=config.dataset.tokenizer_threads,
            )
            print(" - Max queue size: 4")
            print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
            print(f" - Total items: {tokenization_queue.total_items}")
        except Exception as e:
            print(f"✗ Failed to create TokenizationQueue: {e}")
            traceback.print_exc()
            sys.exit(1)
        # Create async batch iterator
        print("✓ Creating AsyncBatchIterator...")
        try:
            async_loader = AsyncBatchIterator(
                tokenization_queue=tokenization_queue,
                batch_size=config.batch_size,
                device=device,
                drop_last=True,
                gradient_accumulation_steps=config.gradient_accumulation_steps,
            )
            print(f" - Batch size: {config.batch_size}")
            print(f" - Device: {device}")
            print(f" - Gradient accumulation: {config.gradient_accumulation_steps}")
        except Exception as e:
            print(f"✗ Failed to create AsyncBatchIterator: {e}")
            traceback.print_exc()
            sys.exit(1)
    else:
        print_separator("STEP 4: USING STANDARD DATALOADER")
        from taoTrain.data.loaders import get_dataloader
        tokenizer = getattr(train_dataset, "tokenizer", None)
        # Get special token IDs
        special_token_ids = get_special_token_ids(tokenizer)
        if tokenizer is not None:
            print("✓ Special token IDs extracted:")
            for token_name, token_id in special_token_ids.items():
                print(f" - {token_name.upper()}: {token_id}")
        print("✓ HuggingFace dataset detected")
        print(f"✓ Will use standard DataLoader (batch_size={config.batch_size})")
        # Create standard dataloader
        try:
            async_loader = get_dataloader(
                train_dataset,
                config,
                shuffle=False,
                drop_last=False,
            )
            print("✓ DataLoader created")
        except Exception as e:
            print(f"✗ Failed to create DataLoader: {e}")
            traceback.print_exc()
            sys.exit(1)

    # ========================================================================
    # Step 5: Fetch and display samples
    # ========================================================================
    print_separator(f"STEP 5: FETCH FIRST {num_samples} BATCHES")
    if async_loader is None:
        print("✗ No data loader available")
        sys.exit(1)
    total_samples = 0
    total_tokens = 0
    batches_fetched = 0
    cumulative_special_counts = {"bos": 0, "eos": 0, "pad": 0, "unk": 0}
    # islice stops after exactly `batch_limit` batches, so no extra batch is
    # pulled from the (possibly background-threaded) loader just to hit `break`.
    batch_limit = max(num_samples, 0)
    try:
        batch_iter = itertools.islice(async_loader, batch_limit)
        for batch_idx, batch in enumerate(tqdm(batch_iter, total=batch_limit, desc="Fetching batches")):
            batches_fetched += 1
            print_separator(f"BATCH {batch_idx}", char="-", width=60)
            print(f"Batch keys: {batch.keys()}")
            # Print batch shapes
            for key, tensor in batch.items():
                print(f" {key}: {format_tensor_info(tensor)}")
            # Count tokens and special tokens
            if "input_ids" in batch:
                batch_token_count = batch["input_ids"].numel()
                total_tokens += batch_token_count
                print(f" Total tokens in batch: {batch_token_count}")
                if special_token_ids is not None:
                    batch_input_ids = batch["input_ids"].reshape(-1)  # Flatten batch
                    batch_special_counts = count_special_tokens(batch_input_ids, special_token_ids)
                    special_str = ", ".join(f"{name}={count}" for name, count in batch_special_counts.items())
                    print(f" Batch special tokens: {special_str}")
                    for token_name, count in batch_special_counts.items():
                        cumulative_special_counts[token_name] = cumulative_special_counts.get(token_name, 0) + count
            # Print individual samples in batch
            batch_size = next(iter(batch.values())).shape[0]
            total_samples += batch_size
            print(f"\nSamples in batch (showing first 3 of {batch_size}):")
            for sample_idx in range(min(3, batch_size)):
                sample = {key: tensor[sample_idx] for key, tensor in batch.items()}
                global_idx = batch_idx * batch_size + sample_idx
                if show_sft and isinstance(train_dataset, SFTJSONLDataset):
                    mask = None
                    try:
                        chunk_data = getattr(train_dataset, '_current_chunk_data', None)
                        if chunk_data is not None and "mask" in chunk_data:
                            # NOTE(review): best-effort mapping. Batches come from the
                            # loader, not from this chunk, so the mask may not belong to
                            # the sample shown. Bounds check and lookup now use the same
                            # index (previously checked global_idx but read sample_idx).
                            chunk_masks = chunk_data["mask"]
                            if global_idx < len(chunk_masks):
                                mask = chunk_masks[global_idx]
                    except Exception as e:
                        print(f" [Could not access mask: {e}]")
                    print_sample_sft(global_idx, sample, mask, tokenizer, special_token_ids, max_display)
                else:
                    print_sample(global_idx, sample, tokenizer, special_token_ids, max_display)
            if batch_size > 3:
                print(f" ... and {batch_size - 3} more samples")
    except Exception as e:
        print(f"✗ Error during batch fetching: {e}")
        traceback.print_exc()
        sys.exit(1)

    # ========================================================================
    # Summary Statistics
    # ========================================================================
    print_separator("SUMMARY STATISTICS")
    print(f"Total batches fetched: {batches_fetched}")
    print(f"Total samples: {total_samples}")
    print(f"Total tokens: {total_tokens}")
    if total_samples > 0:
        print(f"Average tokens per sample: {total_tokens / total_samples:.1f}")
    # Print special token summary
    if special_token_ids is not None:
        print("\nSpecial token counts across all batches:")
        for token_name in special_token_ids:
            total_count = cumulative_special_counts.get(token_name, 0)
            pct = (total_count / total_tokens * 100) if total_tokens > 0 else 0
            print(f" - {token_name.upper()}: {total_count} tokens ({pct:.2f}%)")
    print_separator()
    print("✓ Inspection complete!")
def main():
    """Parse command-line options and hand them to ``inspect_dataset``."""
    parser = argparse.ArgumentParser(
        description="Test dataset loading pipeline - inspect data shapes, tokenization, and samples"
    )
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to YAML/JSON config file (e.g., config/pretrain.yaml)",
    )
    # The integer knobs share one shape: (flag, default, help text).
    integer_options = (
        ("--num-samples", 10, "Number of batches to display (default: 10)"),
        ("--max-display", 20, "Maximum tokens to display in preview (default: 20)"),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument(
        "--sft",
        action="store_true",
        help="If set, show SFT-specific masking information (user vs assistant tokens)",
    )
    cli = parser.parse_args()
    options = {
        "num_samples": cli.num_samples,
        "max_display": cli.max_display,
        "show_sft": cli.sft,
    }
    inspect_dataset(cli.config_path, **options)
# Script entry point: run the CLI only when executed directly, so importing
# this module (e.g. during test collection) has no side effects.
if __name__ == "__main__":
    main()