| |
|
|
| """Utility functions for logging and printing MIMO model structure.""" |
|
|
| |
| from megatron.training.utils import print_rank_0 |
|
|
|
|
| def print_mimo_structure(model): |
| """Print a clean summary of MIMO model structure showing components and their types.""" |
| print_rank_0("MIMO Model Structure:") |
| |
| |
| print_rank_0("├── Modalities:") |
| if hasattr(model, 'modality_submodules'): |
| for modality_name, submodule in model.modality_submodules.items(): |
| print_rank_0(f"│ ├── {modality_name}") |
| |
| |
| if hasattr(submodule, 'encoders') and submodule.encoders: |
| print_rank_0("│ │ ├── Encoders:") |
| for encoder_name, encoder in submodule.encoders.items(): |
| encoder_type = encoder.__class__.__name__ |
| print_rank_0(f"│ │ │ ├── {encoder_name}: {encoder_type}") |
| |
| |
| if hasattr(submodule, 'input_projections') and submodule.input_projections: |
| print_rank_0("│ │ ├── Input Projections:") |
| for i, proj in enumerate(submodule.input_projections): |
| proj_type = proj.__class__.__name__ |
| print_rank_0(f"│ │ │ ├── {i}: {proj_type}") |
| |
| |
| if hasattr(submodule, 'decoders') and submodule.decoders: |
| print_rank_0("│ │ ├── Decoders:") |
| for decoder_name, decoder in submodule.decoders.items(): |
| decoder_type = decoder.__class__.__name__ |
| print_rank_0(f"│ │ │ ├── {decoder_name}: {decoder_type}") |
| |
| |
| if hasattr(submodule, 'output_projections') and submodule.output_projections: |
| print_rank_0("│ │ ├── Output Projections:") |
| for i, proj in enumerate(submodule.output_projections): |
| proj_type = proj.__class__.__name__ |
| print_rank_0("│ │ │ ├── {i}: {proj_type}") |
| |
| |
| if hasattr(model, 'language_model'): |
| lm_type = model.language_model.__class__.__name__ |
| print_rank_0(f"├── Language Model: {lm_type}") |