Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| """ | |
| Simple mock data module for testing MIMO with image-text (VLM) models. | |
| This module provides basic synthetic data generation for testing Vision Language Models | |
| within the MIMO framework. | |
| """ | |
| from typing import Callable, Dict, List, Optional | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| def create_mock_image(image_size: int = 336) -> torch.Tensor: | |
| """ | |
| Create a simple mock image (all zeros). | |
| Args: | |
| image_size: Size of the square image | |
| Returns: | |
| Tensor of shape [3, H, W] with all zeros | |
| """ | |
| return torch.zeros(3, image_size, image_size) | |
| def create_mock_caption() -> str: | |
| """ | |
| Create a simple mock caption. | |
| Returns: | |
| A simple caption string | |
| """ | |
| return "This is an image." | |
| class MockVLMDataset(Dataset): | |
| """Simple dataset of mock image-text pairs for VLM testing.""" | |
| def __init__( | |
| self, | |
| size: int = 10000, | |
| image_size: int = 336, | |
| seq_len: int = 512, | |
| image_seq_length: int = 32, | |
| vocab_size: int = 256, | |
| tokenizer: Optional[Callable] = None, | |
| pad_token_id: int = 0, | |
| image_token_id: int = 32000, | |
| ): | |
| """ | |
| Initialize the mock VLM dataset. | |
| Args: | |
| size: Number of examples in the dataset | |
| image_size: Size of the square images | |
| seq_len: Total length of the token sequence (image + text) | |
| image_seq_length: Number of image tokens to pad | |
| vocab_size: Size of the vocabulary for tokenization | |
| tokenizer: Optional tokenizer function | |
| pad_token_id: ID for padding token | |
| image_token_id: ID for image placeholder token | |
| """ | |
| self.size = size | |
| self.image_size = image_size | |
| self.seq_len = seq_len | |
| self.image_seq_length = image_seq_length | |
| self.vocab_size = vocab_size | |
| self.tokenizer = tokenizer | |
| # Special token IDs | |
| self.pad_token_id = pad_token_id | |
| self.image_token_id = image_token_id | |
| if self.seq_len < self.image_seq_length: | |
| raise ValueError( | |
| f"seq_len ({self.seq_len}) must be >= image_seq_length ({self.image_seq_length})." | |
| ) | |
| def __len__(self) -> int: | |
| """Return the size of the dataset.""" | |
| return self.size | |
| def __getitem__(self, idx: int) -> Dict: | |
| """ | |
| Get an item from the dataset. | |
| Args: | |
| idx: Index of the item (ignored, all items are identical) | |
| Returns: | |
| Dictionary containing: | |
| - images: Tensor of shape [C, H, W] | |
| - input_ids: Tokenized caption with image token | |
| - labels: Shifted input_ids for language modeling | |
| - loss_mask: Mask for loss calculation | |
| - position_ids: Position IDs for the tokens | |
| """ | |
| # Create a zero image | |
| image = create_mock_image(self.image_size) | |
| # Generate random token sequence for this sample. | |
| input_ids = self._mock_tokenize() | |
| # Create labels (shifted input_ids) | |
| labels = input_ids.clone() | |
| labels[:-1] = input_ids[1:] | |
| labels[-1] = self.pad_token_id # Padding for the last position | |
| # Set labels for image tokens to -100 (ignored in loss calculation) | |
| labels[input_ids == self.image_token_id] = -100 | |
| # Create loss mask (1 for tokens to calculate loss on, 0 for others) | |
| loss_mask = torch.ones_like(input_ids).float() | |
| loss_mask[input_ids == self.pad_token_id] = 0.0 # Don't calculate loss on padding | |
| loss_mask[input_ids == self.image_token_id] = 0.0 # Don't calculate loss on image tokens | |
| # Create position IDs (just sequential integers) | |
| position_ids = torch.arange(len(input_ids), dtype=torch.long) | |
| return { | |
| "input_ids": input_ids, | |
| "labels": labels, | |
| "loss_mask": loss_mask, | |
| "position_ids": position_ids, | |
| "modality_inputs": { | |
| "clip_encoder": { | |
| "images": image, | |
| } | |
| }, | |
| } | |
| def _mock_tokenize(self) -> torch.Tensor: | |
| """ | |
| Generate a mock token sequence consisting of ``image_seq_length`` image tokens followed by | |
| randomly generated text tokens such that the total sequence length equals | |
| ``self.seq_len``. | |
| Returns: | |
| torch.Tensor: Tensor of token IDs of shape ``[seq_len]``. | |
| """ | |
| # Image placeholder tokens ─ placed at the beginning of the sequence to mimic | |
| # the layout produced by many VLM tokenizers. | |
| image_tokens = torch.full( | |
| (self.image_seq_length,), self.image_token_id, dtype=torch.long | |
| ) | |
| # Random text tokens drawn uniformly in ``[1, vocab_size)`` (we reserve ``0`` for pad). | |
| num_text_tokens = self.seq_len - self.image_seq_length | |
| text_tokens = torch.randint( | |
| low=1, | |
| high=self.vocab_size, | |
| size=(num_text_tokens,), | |
| dtype=torch.long, | |
| ) | |
| # Concatenate to form the full sequence. | |
| token_ids = torch.cat((image_tokens, text_tokens), dim=0) | |
| return token_ids | |
| def get_mock_vlm_dataloader( | |
| batch_size: int = 8, | |
| dataset_size: int = 100, | |
| image_size: int = 224, | |
| seq_len: int = 77, | |
| image_seq_length: int = 32, | |
| num_workers: int = 0, | |
| pad_token_id: int = 0, | |
| image_token_id: int = 50000, | |
| ) -> DataLoader: | |
| """ | |
| Create a DataLoader for mock VLM data. | |
| Args: | |
| batch_size: Batch size | |
| dataset_size: Size of the dataset | |
| image_size: Size of the square images | |
| seq_len: Total length of the token sequence (image + text) | |
| image_seq_length: Number of image tokens to pad | |
| num_workers: Number of worker processes for data loading | |
| pad_token_id: ID for padding token | |
| image_token_id: ID for image placeholder token | |
| Returns: | |
| DataLoader for the mock VLM dataset | |
| """ | |
| dataset = MockVLMDataset( | |
| size=dataset_size, | |
| image_size=image_size, | |
| seq_len=seq_len, | |
| image_seq_length=image_seq_length, | |
| pad_token_id=pad_token_id, | |
| image_token_id=image_token_id, | |
| ) | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers, | |
| collate_fn=lambda batch: _collate_fn(batch), | |
| ) | |
| return dataloader | |
| def _collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]: | |
| """ | |
| Collate function for the DataLoader. | |
| Args: | |
| batch: List of dictionaries from the dataset | |
| Returns: | |
| Dictionary of batched tensors | |
| """ | |
| images = torch.stack([item["images"] for item in batch]) | |
| input_ids = torch.stack([item["input_ids"] for item in batch]) | |
| labels = torch.stack([item["labels"] for item in batch]) | |
| loss_mask = torch.stack([item["loss_mask"] for item in batch]) | |
| position_ids = torch.stack([item["position_ids"] for item in batch]) | |
| return { | |
| "input_ids": input_ids, | |
| "labels": labels, | |
| "loss_mask": loss_mask, | |
| "position_ids": position_ids, | |
| "modality_inputs": { | |
| "clip_encoder": { | |
| "images": images, | |
| } | |
| }, | |
| } | |
| def train_valid_test_datasets_provider(train_val_test_num_samples): | |
| """Provide datasets for training, validation, and testing.""" | |
| from megatron.core import mpu | |
| from megatron.training import get_args | |
| args = get_args() | |
| # Print some info to confirm args are available | |
| print(f"Creating datasets with batch size: {args.micro_batch_size}") | |
| print(f"Image size: {args.image_size}") | |
| print(f"Image sequence length: {args.image_seq_length}") | |
| print(f"Total sequence length: {args.total_seq_length}") | |
| # Only build dataset on tensor parallel rank 0 | |
| if mpu.get_tensor_model_parallel_rank() == 0: | |
| from examples.mimo.data.mock import MockVLMDataset | |
| train_dataset = MockVLMDataset( | |
| size=train_val_test_num_samples[0], | |
| image_size=args.image_size, | |
| seq_len=args.total_seq_length, | |
| image_seq_length=args.image_seq_length, | |
| pad_token_id=args.pad_token_id, | |
| image_token_id=args.image_token_id, | |
| ) | |
| # Use the same dataset type for validation | |
| valid_dataset = MockVLMDataset( | |
| size=train_val_test_num_samples[1] if train_val_test_num_samples[1] > 0 else 100, | |
| image_size=args.image_size, | |
| seq_len=args.total_seq_length, | |
| image_seq_length=args.image_seq_length, | |
| pad_token_id=args.pad_token_id, | |
| image_token_id=args.image_token_id, | |
| ) | |
| # No test dataset for now | |
| test_dataset = None | |
| else: | |
| train_dataset = None | |
| valid_dataset = None | |
| test_dataset = None | |
| return train_dataset, valid_dataset, test_dataset | |
| if __name__ == "__main__": | |
| print("\nCreating mock VLM dataloader...") | |
| dataloader = get_mock_vlm_dataloader(batch_size=4, dataset_size=10) | |
| print(f"DataLoader has {len(dataloader)} batches") | |
| for batch in dataloader: | |
| print("\nBatch from dataloader:") | |
| for key, tensor in batch.items(): | |
| print(f" {key}: {tensor.shape}") | |
| break | |