Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| """Utility functions for logging and printing MIMO model structure.""" | |
| # Use Megatron utility if available – covers both distributed and non-distributed cases. | |
| from megatron.training.utils import print_rank_0 | |
def _print_component_group(title, labelled_components):
    """Print one group header, then a ``label: ClassName`` line per component.

    Args:
        title: Heading text for the group (e.g. ``"Encoders"``).
        labelled_components: Iterable of ``(label, component)`` pairs; the
            label is printed verbatim and the component is reported by its
            class name.
    """
    print_rank_0(f"│ │ ├── {title}:")
    for label, component in labelled_components:
        print_rank_0(f"│ │ │ ├── {label}: {component.__class__.__name__}")


def print_mimo_structure(model):
    """Print a clean summary of MIMO model structure showing components and their types.

    All output goes through ``print_rank_0``. Every attribute read here is
    optional: a missing (or, for the per-modality groups, empty) attribute
    simply produces no lines for that part of the tree.

    Args:
        model: Object to summarize. If it has ``modality_submodules`` (a
            mapping of modality name to submodule), each submodule's
            ``encoders`` / ``decoders`` (mappings, labelled by key) and
            ``input_projections`` / ``output_projections`` (sequences,
            labelled by index) are listed. If it has ``language_model``, its
            class name is printed last.
    """
    print_rank_0("MIMO Model Structure:")

    # The "Modalities" heading is printed even when the model has none.
    print_rank_0("├── Modalities:")
    if hasattr(model, 'modality_submodules'):
        for modality_name, submodule in model.modality_submodules.items():
            print_rank_0(f"│ ├── {modality_name}")

            encoders = getattr(submodule, 'encoders', None)
            if encoders:
                _print_component_group("Encoders", encoders.items())

            input_projections = getattr(submodule, 'input_projections', None)
            if input_projections:
                _print_component_group("Input Projections", enumerate(input_projections))

            decoders = getattr(submodule, 'decoders', None)
            if decoders:
                _print_component_group("Decoders", decoders.items())

            # Output projections now go through the same formatted line as the
            # other groups; previously this line was a plain string and printed
            # the placeholders "{i}: {proj_type}" literally.
            output_projections = getattr(submodule, 'output_projections', None)
            if output_projections:
                _print_component_group("Output Projections", enumerate(output_projections))

    if hasattr(model, 'language_model'):
        lm_type = model.language_model.__class__.__name__
        print_rank_0(f"├── Language Model: {lm_type}")