| import transformers | |
def replace_llama_rmsnorm_with_fused_rmsnorm() -> None:
    """Swap ``LlamaRMSNorm`` in transformers for apex's ``FusedRMSNorm``.

    On success, ``transformers.models.llama.modeling_llama.LlamaRMSNorm`` is
    rebound to ``functools.partial(FusedRMSNorm, eps=1e-6)``, so code that
    looks the name up on that module afterwards gets the fused version.
    Objects created before this call are not touched.

    This is a best-effort optimisation and never raises:

    * if ``apex.normalization.FusedRMSNorm`` cannot be imported, the stock
      ``LlamaRMSNorm`` is kept silently;
    * if anything else goes wrong, the stock ``LlamaRMSNorm`` is kept and a
      message including the underlying error is printed.
    """
    # Stdlib import kept local to the function, but outside the ``try`` so the
    # guarded region only contains the steps that can fail.
    from functools import partial

    try:
        from apex.normalization import FusedRMSNorm

        # ``partial`` only presets the default eps; a caller that passes
        # ``eps=...`` explicitly still overrides it.
        fused_rmsnorm = partial(FusedRMSNorm, eps=1e-6)
        transformers.models.llama.modeling_llama.LlamaRMSNorm = fused_rmsnorm
        print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")
    except ImportError:
        # apex (or its FusedRMSNorm) is unavailable: keep the normal
        # LlamaRMSNorm without any noise.
        pass
    except Exception as err:
        # Deliberately broad: a half-broken apex install must not abort the
        # caller. Report the cause so the fallback can be diagnosed.
        print(
            "discovered apex but it failed to load, falling back to LlamaRMSNorm"
            f" ({type(err).__name__}: {err})"
        )