| # modeling_ndlinear_dit.py | |
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel, PretrainedConfig | |
| from mlp import NdMlp | |
| from ndlinear import NdLinear | |
| from models_hf import DiT, DiTConfig | |
class DiTConfig(PretrainedConfig):
    """Configuration object for the NdLinear-based DiT model in this module.

    Adds no configuration fields of its own beyond what ``PretrainedConfig``
    provides; it only overrides ``model_type`` so that saved configs carry the
    ``"ndlinear_dit"`` identifier.

    NOTE(review): this class rebinds the name ``DiTConfig`` that is imported
    from ``models_hf`` at the top of the file, so that import is never used.
    Confirm which definition is intended.
    """

    # Identifier written into serialized configs for this model family.
    model_type = "ndlinear_dit"
class DiT(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper for the NdLinear DiT.

    The only thing set here is ``config_class``, which points at the
    ``DiTConfig`` defined earlier in this module (the module-level name at
    class-creation time). No ``__init__`` or ``forward`` is defined in this
    block, so both are inherited from ``PreTrainedModel``.

    NOTE(review): no layers are built here. Presumably the actual
    architecture lives elsewhere, possibly in ``models_hf.DiT``, whose
    imported name this class rebinds. TODO confirm.
    """

    # Config type used by the PreTrainedModel machinery for this model.
    config_class = DiTConfig
# Public names exported by ``from modeling_ndlinear_dit import *``.
__all__ = [
    "DiT",
    "DiTConfig",
]