Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """ | |
| Test script to inspect dataset loading pipeline. | |
| Usage: | |
| python test_dataset.py config/pretrain.yaml | |
| python test_dataset.py config/sft.yaml --num-samples 10 | |
| python test_dataset.py config/sft.yaml --sft --num-samples 5 # Show SFT-specific info | |
| """ | |
| import sys | |
| import argparse | |
| from pathlib import Path | |
| from typing import Optional, Dict | |
| import torch | |
| from tqdm import tqdm | |
| # Add src to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) | |
| from taoTrain.config import load_config, TrainingModeEnum, PretrainConfig | |
| from taoTrain.core import create_datasets | |
| from taoTrain.data import TokenizationQueue, AsyncBatchIterator | |
| from taoTrain.data import BaseJSONLDataset, SFTJSONLDataset, parse_sft_record | |
| from taoTrain.utils import set_seed, get_device | |
def print_separator(title: str = "", char: str = "=", width: int = 80):
    """Print a horizontal rule, preceded by a blank line.

    Args:
        title: Optional heading. When non-empty it is placed near the left
            edge of the rule, surrounded by single spaces.
        char: Character (or string) the rule is built from.
        width: Target width of the rule in repetitions of ``char``.
    """
    if not title:
        rule = char * width
    else:
        # Leading char + two spaces around the title account for the "3".
        filler = char * (width - len(title) - 3)
        rule = f"{char} {title} {filler}"
    print(f"\n{rule}")
def format_tensor_info(tensor: torch.Tensor) -> str:
    """Summarize a tensor's shape, dtype and device as a one-line string.

    Args:
        tensor: Tensor to describe.

    Returns:
        A string of the form ``shape=(...), dtype=..., device=...``.
    """
    dims = tuple(tensor.shape)
    dtype, device = tensor.dtype, tensor.device
    return f"shape={dims}, dtype={dtype}, device={device}"
def get_special_token_ids(tokenizer) -> Dict[str, Optional[int]]:
    """Extract special token IDs from a tokenizer on a best-effort basis.

    For each of ``bos``, ``eos``, ``pad`` and ``unk`` two accessor styles are
    probed, in this order:

    1. ``<name>_id`` (e.g. ``bos_id``), called if it is callable, otherwise
       used directly as a value.
    2. ``<name>_token_id`` (e.g. ``bos_token_id``), treated the same way.

    A later accessor only replaces an earlier result when it yields a
    non-``None`` value, so an unset ``*_token_id`` attribute does not erase an
    ID that ``*_id`` already provided. Each accessor is probed independently:
    a failure in one does not prevent the remaining lookups.

    Args:
        tokenizer: Tokenizer object to inspect, or ``None``.

    Returns:
        Dict with keys ``bos``, ``eos``, ``pad``, ``unk``. A value is ``None``
        when no accessor produced an ID for that token.
    """
    special_tokens: Dict[str, Optional[int]] = {
        "bos": None,
        "eos": None,
        "pad": None,
        "unk": None,
    }
    if tokenizer is None:
        return special_tokens

    for name in tuple(special_tokens):
        for attr in (f"{name}_id", f"{name}_token_id"):
            try:
                value = getattr(tokenizer, attr, None)
                if callable(value):
                    value = value()
            except Exception:
                # Best-effort probe of an unknown tokenizer type: skip only
                # this accessor and keep trying the others.
                continue
            if value is not None:
                special_tokens[name] = value
    return special_tokens
def count_special_tokens(token_ids: torch.Tensor, special_token_ids: Dict[str, Optional[int]]) -> Dict[str, int]:
    """Count occurrences of special tokens in a sequence of token IDs.

    Args:
        token_ids: Token IDs to scan. Normally a tensor; other array-likes
            (e.g. a list of ints) are converted with ``torch.as_tensor``.
        special_token_ids: Dict mapping special token names to IDs. A name
            mapped to ``None`` is reported with a count of 0.

    Returns:
        Dict with one plain-``int`` count per name in ``special_token_ids``.
    """
    if isinstance(token_ids, torch.Tensor):
        ids = token_ids
    else:
        # Plain sequences do not support element-wise ``==``; convert first.
        ids = torch.as_tensor(token_ids)
    # Move to CPU once so the per-token comparisons below do not each
    # trigger a separate device transfer.
    ids = ids.detach().cpu()

    return {
        token_name: int((ids == token_id).sum().item()) if token_id is not None else 0
        for token_name, token_id in special_token_ids.items()
    }
def print_sample(sample_idx: int, sample: Dict[str, torch.Tensor], tokenizer=None,
                 special_token_ids: Optional[Dict[str, Optional[int]]] = None, max_display: int = 20):
    """Print a single sample with formatted information.

    Each of ``input_ids``, ``attention_mask`` and ``labels`` is reported only
    when present in ``sample``; a missing key is skipped rather than raising.

    Args:
        sample_idx: Index shown in the sample header.
        sample: Mapping of field name to a per-sample tensor.
        tokenizer: Optional tokenizer; when given, the displayed prefix of
            ``input_ids`` is decoded for a text preview.
        special_token_ids: Optional mapping of special token names to IDs;
            when given, their counts within ``input_ids`` are printed.
        max_display: Maximum number of leading tokens to show per field.
    """
    print(f"\n Sample {sample_idx}:")

    # Input IDs
    if "input_ids" in sample:
        input_ids = sample["input_ids"]
        truncated = len(input_ids) > max_display
        print(f" input_ids: {format_tensor_info(input_ids)}")
        print(f" Values: {input_ids[:max_display].tolist()}{'...' if truncated else ''}")

        # Count special tokens
        if special_token_ids is not None:
            special_counts = count_special_tokens(input_ids, special_token_ids)
            special_summary = ", ".join(f"{name}={count}" for name, count in special_counts.items())
            print(f" Special tokens: {special_summary}")

        # Try to decode if tokenizer available
        if tokenizer is not None:
            try:
                decoded = tokenizer.decode(input_ids.tolist()[:max_display], skip_special_tokens=False)
                preview = decoded.replace('\n', '\\n')
                # Only mark the preview as truncated when tokens were cut off.
                print(f" Decoded: {preview}{'...' if truncated else ''}")
            except Exception as e:
                # Display-only helper: report the failure and carry on.
                print(f" [Decode error: {e}]")

    # Attention mask
    if "attention_mask" in sample:
        attention_mask = sample["attention_mask"]
        print(f" attention_mask: {format_tensor_info(attention_mask)}")
        non_pad_count = attention_mask.sum().item()
        print(f" Non-padding tokens: {non_pad_count}/{len(attention_mask)}")

    # Labels (reported once, and only when the sample actually has them)
    if "labels" in sample:
        labels = sample["labels"]
        print(f" labels: {format_tensor_info(labels)}")
        print(f" Values: {labels[:max_display].tolist()}{'...' if len(labels) > max_display else ''}")
        valid_labels = labels[labels != -100]
        print(f" Valid labels (not -100): {len(valid_labels)}/{len(labels)}")
def print_sample_sft(sample_idx: int, sample: Dict[str, torch.Tensor], mask: Optional[list] = None,
                     tokenizer=None, special_token_ids: Optional[Dict[str, Optional[int]]] = None,
                     max_display: int = 20):
    """Print a single SFT sample with detailed masking information.

    Shows:
    - Which tokens are user input (mask=0, labeled -100)
    - Which tokens are assistant output (mask=1, labeled with token id)
    - Per-token decoding for the first ``max_display`` tokens

    When ``mask`` is ``None`` only a label-based breakdown (masked vs.
    training tokens) is printed. Missing ``input_ids`` or ``labels`` entries
    are skipped instead of raising.

    Args:
        sample_idx: Index shown in the sample header.
        sample: Mapping of field name to a per-sample tensor.
        mask: Optional per-token mask (0 = user, 1 = assistant).
        tokenizer: Optional tokenizer used to decode individual tokens.
        special_token_ids: Accepted for signature parity with
            ``print_sample``; not used by this function.
        max_display: Maximum number of leading tokens to list.
    """
    print(f"\n SFT Sample {sample_idx}:")

    input_ids = sample.get("input_ids")
    labels = sample.get("labels")
    if input_ids is not None:
        print(f" input_ids: {format_tensor_info(input_ids)}")
    if labels is not None:
        print(f" labels: {format_tensor_info(labels)}")

    if mask is None:
        # Fallback: just show labels
        if labels is not None:
            masked_count = (labels == -100).sum().item()
            valid_count = (labels != -100).sum().item()
            print(" Token labels breakdown:")
            print(f" - Masked tokens (label=-100): {masked_count}")
            print(f" - Training tokens: {valid_count}")
        return

    # Show mask breakdown. Normalize to a plain list so tensors and lists
    # behave the same in the comparisons below.
    mask_values = mask.tolist() if hasattr(mask, "tolist") else list(mask)
    user_count = sum(1 for m in mask_values if m == 0)
    assistant_count = sum(1 for m in mask_values if m == 1)
    print(" Mask breakdown:")
    print(f" - User input tokens (mask=0, ignored): {user_count}")
    print(f" - Assistant output tokens (mask=1, trained): {assistant_count}")

    if input_ids is None:
        return

    # Show which regions are which
    shown_ids = input_ids[:max_display].tolist()
    shown_labels = labels[:max_display].tolist() if labels is not None else []
    print(f" Token regions (first {max_display} tokens):")
    for i, token_id in enumerate(shown_ids):
        # Positions beyond the mask are treated as user tokens.
        mask_val = mask_values[i] if i < len(mask_values) else 0
        region_type = "USER " if mask_val == 0 else "ASST "

        if i < len(shown_labels):
            label = shown_labels[i]
            label_str = "IGNORE" if label == -100 else f"{label:5d}"
        else:
            label_str = "N/A"

        # Try to decode token
        token_str = "?"
        if tokenizer is not None:
            try:
                token_str = tokenizer.decode([token_id], skip_special_tokens=False).replace('\n', '\\n')[:15]
            except Exception:
                # Fall back to the raw ID; never let KeyboardInterrupt or
                # SystemExit be swallowed here.
                token_str = f"ID:{token_id}"

        print(f" [{i:3d}] {region_type} | label={label_str} | token={token_str}")

    if len(input_ids) > max_display:
        print(f" ... and {len(input_ids) - max_display} more tokens")
def _infer_training_mode(config_path: Path):
    """Guess the training mode from a config path.

    The file name is checked first (as whole ``_``/``-``/``.``-separated
    tokens) so that a directory such as ``pretrain_runs/`` cannot override a
    file called ``sft.yaml``. If the file name gives no hint, the whole path
    is searched by substring. Defaults to PRETRAIN.
    """
    stem_tokens = set(
        config_path.stem.lower().replace("-", "_").replace(".", "_").split("_")
    )
    if "pretrain" in stem_tokens:
        return TrainingModeEnum.PRETRAIN
    if "sft" in stem_tokens:
        return TrainingModeEnum.SFT
    if stem_tokens & {"rl", "dpo"}:
        return TrainingModeEnum.RL

    path_text = str(config_path).lower()
    if "pretrain" in path_text:
        return TrainingModeEnum.PRETRAIN
    if "sft" in path_text:
        return TrainingModeEnum.SFT
    if "rl" in path_text or "dpo" in path_text:
        return TrainingModeEnum.RL
    return TrainingModeEnum.PRETRAIN  # Default


def inspect_dataset(config_path: str, num_samples: int = 10, max_display: int = 20, show_sft: bool = False):
    """Load a training config, build its dataset pipeline and print samples.

    The function prints a report in five steps (config, device/seed, dataset
    creation, loader setup, batch inspection) followed by summary statistics.
    Any failure in a step prints the error and terminates the process via
    ``sys.exit(1)``.

    Args:
        config_path: Path to YAML config file
        num_samples: Number of batches to fetch and display
        max_display: Max tokens to display in preview
        show_sft: If True, show SFT-specific masking information (requires SFT dataset)
    """
    import itertools
    import traceback

    # ========================================================================
    # Step 1: Load and validate config
    # ========================================================================
    print_separator("STEP 1: LOAD CONFIGURATION")
    config_path = Path(config_path)
    if not config_path.exists():
        print(f"✗ Config file not found: {config_path}")
        sys.exit(1)
    print(f"✓ Loading config from: {config_path}")

    # Try to load config - auto-detect mode or use default
    try:
        mode = _infer_training_mode(config_path)
        config = load_config(config_path, mode)
        print(f"✓ Config loaded (mode: {config.mode.value})")
    except Exception as e:
        print(f"✗ Failed to load config: {e}")
        sys.exit(1)

    # Print config summary
    print("\nConfiguration Summary:")
    print(f" - Mode: {config.mode.value}")
    print(f" - Model: {config.model.architecture_type.value}")
    print(f" - Vocab size: {config.model.vocab_size}")
    print(f" - Max seq length: {config.model.max_seq_length}")
    print(f" - Batch size: {config.batch_size}")
    print(f" - Dataset source: {'Local JSONL' if config.dataset.local else 'HuggingFace'}")
    if config.dataset.local:
        print(f" - JSONL path: {config.dataset.jsonl_path}")
        print(f" - Tokenizer: {config.dataset.tokenizer_type or 'auto-detect'}")
        if config.dataset.tokenizer_path:
            print(f" - Tokenizer path: {config.dataset.tokenizer_path}")
        print(f" - Samples per chunk: {config.dataset.samples_per_chunk}")
        print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
    else:
        print(f" - Dataset name: {config.dataset.dataset_name}")
        print(f" - Dataset config: {config.dataset.config or 'default'}")
        print(f" - Split: {config.dataset.split}")
    if config.dataset.max_samples:
        print(f" - Max samples (limit): {config.dataset.max_samples}")

    # Print SFT-specific config if applicable
    if show_sft and hasattr(config, 'response_loss_only'):
        print("\n SFT Configuration:")
        print(f" - Response loss only: {config.response_loss_only}")
        user_token = getattr(config, 'user_token', '<user>')
        assistant_token = getattr(config, 'assistant_token', '<assistant>')
        print(f" - User token: {user_token}")
        print(f" - Assistant token: {assistant_token}")
        checkpoint_path = getattr(config, 'checkpoint_path', None)
        if checkpoint_path:
            print(f" - Checkpoint path: {checkpoint_path}")

    # ========================================================================
    # Step 2: Setup device and random seed
    # ========================================================================
    print_separator("STEP 2: SETUP DEVICE AND SEED")
    device = get_device(config.device)
    print(f"✓ Device: {device}")
    print(f"✓ CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f" - GPU: {torch.cuda.get_device_name(0)}")
        print(f" - CUDA version: {torch.version.cuda}")
    set_seed(config.seed)
    print(f"✓ Random seed set: {config.seed}")

    # ========================================================================
    # Step 3: Create datasets
    # ========================================================================
    print_separator("STEP 3: CREATE DATASETS")
    try:
        train_dataset, val_dataset = create_datasets(config)
        print(f"✓ Train dataset created: {type(train_dataset).__name__}")
        if val_dataset:
            print(f"✓ Validation dataset created: {type(val_dataset).__name__}")
        else:
            print("✓ No validation dataset (JSONL or not configured)")
    except Exception as e:
        print(f"✗ Failed to create datasets: {e}")
        traceback.print_exc()
        sys.exit(1)

    try:
        print(f" - Train dataset length: {len(train_dataset)}")
    except TypeError:
        # Datasets without __len__ should not abort the inspection.
        print(" - Train dataset length: unknown (dataset does not define __len__)")

    # ========================================================================
    # Step 4: Setup async pipeline (for JSONL datasets)
    # ========================================================================
    async_loader = None
    tokenizer = None
    special_token_ids = None
    if isinstance(train_dataset, BaseJSONLDataset):
        print_separator("STEP 4: SETUP ASYNC PIPELINE")
        # Extract components
        chunk_manager = train_dataset.chunk_manager
        tokenizer = train_dataset.tokenizer

        # Get special token IDs
        special_token_ids = get_special_token_ids(tokenizer)
        print("✓ Special token IDs extracted:")
        for token_name, token_id in special_token_ids.items():
            print(f" - {token_name.upper()}: {token_id}")

        print("✓ ChunkManager found:")
        print(f" - Total chunks: {chunk_manager.num_chunks}")
        print(f" - Effective lines: {chunk_manager.effective_lines}")
        print(f" - File size: {chunk_manager.file_size_bytes / (1024**2):.1f} MB")
        line_ranges = chunk_manager.chunk_line_ranges
        if len(line_ranges) > 3:
            print(f" - Chunk line ranges: {line_ranges[:3]}...")
        else:
            print(f" - Chunk line ranges: {line_ranges}")
        print(f"✓ Tokenizer found: {type(tokenizer).__name__}")

        # Create tokenization queue
        print("✓ Creating TokenizationQueue...")
        try:
            tokenization_queue = TokenizationQueue(
                chunk_manager=chunk_manager,
                tokenizer=tokenizer,
                config=config,
                max_queue_size=4,
                shuffle_chunks=True,
                num_threads=config.dataset.tokenizer_threads,
            )
            print(" - Max queue size: 4")
            print(f" - Tokenizer threads: {config.dataset.tokenizer_threads}")
            print(f" - Total items: {tokenization_queue.total_items}")
        except Exception as e:
            print(f"✗ Failed to create TokenizationQueue: {e}")
            traceback.print_exc()
            sys.exit(1)

        # Create async batch iterator
        print("✓ Creating AsyncBatchIterator...")
        try:
            async_loader = AsyncBatchIterator(
                tokenization_queue=tokenization_queue,
                batch_size=config.batch_size,
                device=device,
                drop_last=True,
                gradient_accumulation_steps=config.gradient_accumulation_steps,
            )
            print(f" - Batch size: {config.batch_size}")
            print(f" - Device: {device}")
            print(f" - Gradient accumulation: {config.gradient_accumulation_steps}")
        except Exception as e:
            print(f"✗ Failed to create AsyncBatchIterator: {e}")
            traceback.print_exc()
            sys.exit(1)
    else:
        print_separator("STEP 4: USING STANDARD DATALOADER")
        from taoTrain.data.loaders import get_dataloader

        tokenizer = getattr(train_dataset, "tokenizer", None)

        # Get special token IDs
        special_token_ids = get_special_token_ids(tokenizer)
        if tokenizer is not None:
            print("✓ Special token IDs extracted:")
            for token_name, token_id in special_token_ids.items():
                print(f" - {token_name.upper()}: {token_id}")

        print("✓ HuggingFace dataset detected")
        print(f"✓ Will use standard DataLoader (batch_size={config.batch_size})")

        # Create standard dataloader
        try:
            train_loader = get_dataloader(
                train_dataset,
                config,
                shuffle=False,
                drop_last=False,
            )
            async_loader = train_loader
            print("✓ DataLoader created")
        except Exception as e:
            print(f"✗ Failed to create DataLoader: {e}")
            traceback.print_exc()
            sys.exit(1)

    # ========================================================================
    # Step 5: Fetch and display samples
    # ========================================================================
    print_separator(f"STEP 5: FETCH FIRST {num_samples} BATCHES")
    if async_loader is None:
        print("✗ No data loader available")
        sys.exit(1)

    total_samples = 0
    total_tokens = 0
    batches_fetched = 0
    cumulative_special_counts = {"bos": 0, "eos": 0, "pad": 0, "unk": 0}
    batch_limit = max(num_samples, 0)
    try:
        # islice stops after exactly ``batch_limit`` batches, so no extra
        # batch is pulled from the loader just to discover the limit.
        limited_batches = itertools.islice(async_loader, batch_limit)
        for batch_idx, batch in enumerate(tqdm(limited_batches, total=batch_limit, desc="Fetching batches")):
            batches_fetched = batch_idx + 1
            print_separator(f"BATCH {batch_idx}", char="-", width=60)
            print(f"Batch keys: {batch.keys()}")

            # Print batch shapes
            for key, tensor in batch.items():
                print(f" {key}: {format_tensor_info(tensor)}")

            # Count tokens and special tokens
            if "input_ids" in batch:
                batch_token_count = batch["input_ids"].numel()
                total_tokens += batch_token_count
                print(f" Total tokens in batch: {batch_token_count}")

                # Count special tokens in entire batch
                if special_token_ids is not None:
                    batch_input_ids = batch["input_ids"].view(-1)  # Flatten batch
                    batch_special_counts = count_special_tokens(batch_input_ids, special_token_ids)
                    special_str = ", ".join(f"{name}={count}" for name, count in batch_special_counts.items())
                    print(f" Batch special tokens: {special_str}")
                    # Accumulate counts
                    for token_name, count in batch_special_counts.items():
                        cumulative_special_counts[token_name] += count

            # Print individual samples in batch
            batch_size = next(iter(batch.values())).shape[0]
            total_samples += batch_size
            print(f"\nSamples in batch (showing first 3 of {batch_size}):")
            for sample_idx in range(min(3, batch_size)):
                sample = {key: tensor[sample_idx] for key, tensor in batch.items()}
                global_idx = batch_idx * batch_size + sample_idx

                if show_sft and isinstance(train_dataset, SFTJSONLDataset):
                    # For SFT datasets, try to get mask info
                    mask = None
                    try:
                        # Access the mask from the dataset's current chunk.
                        # NOTE(review): this assumes batches are drawn in order
                        # from the chunk held in ``_current_chunk_data``; with
                        # shuffled/prefetched chunks the mapping may not hold -
                        # TODO confirm against the dataset implementation.
                        chunk_data = getattr(train_dataset, '_current_chunk_data', None)
                        if chunk_data is not None and "mask" in chunk_data:
                            chunk_masks = chunk_data["mask"]
                            # Index with the same position that is bounds-checked.
                            if global_idx < len(chunk_masks):
                                mask = chunk_masks[global_idx]
                    except Exception as e:
                        print(f" [Could not access mask: {e}]")
                    # Use SFT-specific print function
                    print_sample_sft(global_idx, sample, mask, tokenizer, special_token_ids, max_display)
                else:
                    # Use regular print function
                    print_sample(global_idx, sample, tokenizer, special_token_ids, max_display)

            if batch_size > 3:
                print(f" ... and {batch_size - 3} more samples")
    except Exception as e:
        print(f"✗ Error during batch fetching: {e}")
        traceback.print_exc()
        sys.exit(1)

    # ========================================================================
    # Summary Statistics
    # ========================================================================
    print_separator("SUMMARY STATISTICS")
    print(f"Total batches fetched: {batches_fetched}")
    print(f"Total samples: {total_samples}")
    print(f"Total tokens: {total_tokens}")
    if total_samples > 0:
        print(f"Average tokens per sample: {total_tokens / total_samples:.1f}")

    # Print special token summary
    if special_token_ids is not None:
        print("\nSpecial token counts across all batches:")
        for token_name in special_token_ids:
            total_count = cumulative_special_counts[token_name]
            pct = (total_count / total_tokens * 100) if total_tokens > 0 else 0
            print(f" - {token_name.upper()}: {total_count} tokens ({pct:.2f}%)")

    print_separator()
    print("✓ Inspection complete!")
def main():
    """Command-line entry point: parse arguments and run the inspection."""
    cli = argparse.ArgumentParser(
        description="Test dataset loading pipeline - inspect data shapes, tokenization, and samples"
    )

    # Required positional: which config to inspect.
    cli.add_argument("config_path", type=str,
                     help="Path to YAML/JSON config file (e.g., config/pretrain.yaml)")

    # Optional knobs controlling how much is printed.
    cli.add_argument("--num-samples", type=int, default=10,
                     help="Number of batches to display (default: 10)")
    cli.add_argument("--max-display", type=int, default=20,
                     help="Maximum tokens to display in preview (default: 20)")
    cli.add_argument("--sft", action="store_true",
                     help="If set, show SFT-specific masking information (user vs assistant tokens)")

    options = cli.parse_args()
    inspect_dataset(
        options.config_path,
        num_samples=options.num_samples,
        max_display=options.max_display,
        show_sft=options.sft,
    )


if __name__ == "__main__":
    main()