import argparse import logging import os import sys from dataclasses import dataclass from enum import Enum from typing import Dict, List, Protocol, Union import torch import torch.nn.utils.rnn as rnn_utils from PIL import Image import numpy as np import soundfile as sf from scipy import signal import io from megatron.training.global_vars import get_tokenizer sys.path.append( os.path.abspath( os.path.join( os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir, "examples/multimodal", ) ) ) from dataloader_provider import train_valid_test_dataloaders_provider from transformers import AutoProcessor from examples.mimo.data.utils.calculate_audio_tokens import calculate_num_audio_tokens from megatron.energon import ( DefaultTaskEncoder, VQASample, WorkerConfig, get_loader, get_train_dataset, ) from megatron.energon.task_encoder.base import stateless from megatron.training import get_args from megatron.training.tokenizer.multimodal_tokenizer import mistral_custom_template IMAGE_TOKEN = "" AUDIO_TOKEN = "