# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
from cosmos_predict1.utils import log
from cosmos_predict1.utils.misc import count_params


def get_all_lora_params(model):
    """Return all LoRA weight parameters in the model as (name, parameter) pairs."""
    # Each LoRA adapter is a two-layer net; match both of its linear layers
    # ("lora.net.0" and "lora.net.1") and collect their weight matrices.
    lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name]
    lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()]
    log.info(f"Found {len(lora_params)} LoRA weight matrices")
    return lora_params


def setup_lora_requires_grad(model):
    """Freeze all model parameters except the LoRA parameters; return the LoRA parameter count."""
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing")
    lora_params = get_all_lora_params(model)
    num_lora_param = sum(p.numel() for _, p in lora_params)
    log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M")
    if num_lora_param > 0:
        # Freeze the full model first, then re-enable gradients only for the LoRA weights.
        log.info("Freezing all parameters")
        model.requires_grad_(False)
        log.info("Unfreezing LoRA parameters")
        for name, param in lora_params:
            # log.info(f"Unfreezing LoRA: {name}")
            param.requires_grad_(True)
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing")
    return num_lora_param


def add_lora_layers(model, peft_control_config):
    """Attach LoRA adapters to the attention sub-blocks selected by ``peft_control_config``.

    ``peft_control_config`` maps a block index to a per-sub-block dict keyed by the
    upper-cased block type (e.g. "CA", "FA"), whose "customization_type" entry selects
    the PEFT method to apply.
    """
    for i, block_name in enumerate(model.net.blocks):
        block = model.net.blocks[block_name]
        peft_control = peft_control_config.get(i, {})
        for j, subblock in enumerate(block.blocks):
            block_type = subblock.block_type
            peft_control_subblock = peft_control.get(block_type.upper(), {})
            customization_type = peft_control_subblock.get("customization_type", None)
            if customization_type == CustomizationType.LORA:
                # LoRA is currently wired up only for the attention sub-blocks ("CA" and "FA").
                if block_type.upper() in ["CA", "FA"]:
                    build_attn_lora(subblock.block.attn, peft_control_subblock)
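

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original training flow): it builds a
    # toy module whose parameter paths contain "lora.net.0" / "lora.net.1", mirroring the layout
    # that build_attn_lora is expected to produce, and checks that setup_lora_requires_grad
    # leaves only those parameters trainable. The _ToyLoRALinear class below is hypothetical
    # and exists purely for illustration.
    import torch.nn as nn

    class _ToyLoRALinear(nn.Module):
        def __init__(self, dim: int = 64, rank: int = 4):
            super().__init__()
            self.base = nn.Linear(dim, dim)
            self.lora = nn.Module()
            self.lora.net = nn.Sequential(
                nn.Linear(dim, rank, bias=False),  # registered as "lora.net.0"
                nn.Linear(rank, dim, bias=False),  # registered as "lora.net.1"
            )

    toy = _ToyLoRALinear()
    num_lora = setup_lora_requires_grad(toy)
    assert num_lora == sum(p.numel() for _, p in get_all_lora_params(toy))
    assert all(p.requires_grad == ("lora.net" in name) for name, p in toy.named_parameters())
    log.info("LoRA freeze/unfreeze sketch ran successfully")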