Spaces:
Configuration error
Configuration error
| import os | |
| import torch | |
| from safetensors import safe_open | |
| from loguru import logger | |
| import gc | |
| from functools import lru_cache | |
| from tqdm import tqdm | |
def GET_DTYPE():
    """Return the value of the ``DTYPE`` environment variable, or None if it is unset."""
    return os.environ.get("DTYPE")
class WanLoraWrapper:
    """Register LoRA files by name and merge them into a model's weights in place.

    ``apply_lora`` adds the LoRA deltas directly onto the wrapped model's existing
    parameters; nothing in this class reverts that change.
    """

    # Only tensor keys carrying this prefix are considered.
    _PREFIX = "diffusion_model."
    # (down/A suffix, up/B suffix) naming conventions for low-rank factor pairs.
    _LOW_RANK_SUFFIXES = (
        ("lora_down.weight", "lora_up.weight"),
        ("lora_A.weight", "lora_B.weight"),
    )
    # Full-size difference tensors: (key suffix, parameter suffix it updates).
    _FULL_DIFF_SUFFIXES = (("diff_b", "bias"), ("diff", "weight"))

    def __init__(self, wan_model):
        self.model = wan_model
        # lora_name -> {"path": file path}; populated by load_lora.
        self.lora_metadata = {}

    def load_lora(self, lora_path, lora_name=None):
        """Register a LoRA file under a name without reading it yet.

        Args:
            lora_path: Path to the LoRA file (read later via safetensors).
            lora_name: Name to register under; defaults to the file name with
                its final extension removed.

        Returns:
            The name the LoRA is registered under. If that name was already
            registered, the existing entry is kept and the name is returned.
        """
        if lora_name is None:
            # splitext strips only the last extension, so "style.v2.safetensors"
            # does not collide with "style.v1.safetensors" as split(".")[0] did.
            lora_name = os.path.splitext(os.path.basename(lora_path))[0]
        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name
        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
        return lora_name

    def _load_lora_file(self, file_path, param_dtype):
        """Read every tensor in a safetensors file, cast to ``param_dtype``."""
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(param_dtype) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0, param_dtype=torch.bfloat16, device='cpu'):
        """Merge a registered LoRA into the wrapped model's parameters in place.

        Args:
            lora_name: A name previously returned by ``load_lora``.
            alpha: Scale multiplied into every delta before it is added.
            param_dtype: dtype the LoRA tensors are cast to when read from disk.
            device: Device on which low-rank products are computed; when it is
                ``'cpu'`` the arithmetic is done in float32.

        Returns:
            True when the LoRA was applied; False when ``lora_name`` is not registered.
        """
        if lora_name not in self.lora_metadata:
            logger.warning(f"LoRA {lora_name} not found. Please load it first.")
            # Stop here instead of falling through to a KeyError on the lookup below.
            return False
        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"], param_dtype)
        self._apply_lora_weights(lora_weights, alpha, device)
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    def get_parameter_by_name(self, model, param_name):
        """Resolve a dotted name such as ``blocks.0.attn.weight`` on ``model``.

        Numeric components index into the current object (e.g. a ModuleList);
        others are attribute lookups. Raises AttributeError/IndexError when a
        component does not exist.
        """
        parts = param_name.split('.')
        current = model
        for part in parts:
            if part.isdigit():
                current = current[int(part)]
            else:
                current = getattr(current, part)
        return current

    @classmethod
    def _collect_lora_pairs(cls, lora_weights):
        """Map target parameter names to the LoRA tensor keys that update them.

        Each value is either a ``(down_key, up_key)`` tuple for a low-rank pair
        or a single key string for a full-size difference tensor. Keys without
        the ``diffusion_model.`` prefix, a down/A factor whose up/B partner is
        missing, and unrecognised suffixes are ignored.
        """
        prefix = cls._PREFIX
        lora_pairs = {}
        for key in lora_weights:
            if not key.startswith(prefix):
                continue
            rel = key[len(prefix):]
            low_rank = next(
                ((down, up) for down, up in cls._LOW_RANK_SUFFIXES if rel.endswith(down)),
                None,
            )
            if low_rank is not None:
                down_suffix, up_suffix = low_rank
                stem = rel[: -len(down_suffix)]
                up_key = key[: -len(down_suffix)] + up_suffix
                if up_key in lora_weights:
                    lora_pairs[stem + "weight"] = (key, up_key)
                continue
            for diff_suffix, param_suffix in cls._FULL_DIFF_SUFFIXES:
                if rel.endswith(diff_suffix):
                    # Strip only the suffix; str.replace would also rewrite any
                    # "diff" inside the layer path (e.g. "diffuser").
                    lora_pairs[rel[: -len(diff_suffix)] + param_suffix] = key
                    break
        return lora_pairs

    @staticmethod
    def _compute_delta(lora_weights, spec, alpha, device, param):
        """Return the alpha-scaled update for ``param``, cast to its device and dtype.

        ``spec`` is a value from ``_collect_lora_pairs``: a ``(down_key, up_key)``
        tuple, giving ``up @ down``, or a single key naming a full-size difference.
        """
        # On CPU the arithmetic runs in float32; elsewhere in the parameter's dtype.
        compute_dtype = torch.float32 if device == 'cpu' else param.dtype
        if isinstance(spec, tuple):
            down_key, up_key = spec
            lora_down = lora_weights[down_key].to(device, compute_dtype)
            lora_up = lora_weights[up_key].to(device, compute_dtype)
            delta = torch.matmul(lora_up, lora_down) * alpha
        else:
            delta = lora_weights[spec].to(param.device, compute_dtype) * alpha
        return delta.to(param.device, param.dtype)

    def _apply_lora_weights(self, lora_weights, alpha, device):
        """Add every recognised LoRA delta in ``lora_weights`` to the model in place.

        All target parameters are resolved before any is modified, so a LoRA
        that names a parameter the model lacks raises before anything changes.
        """
        lora_pairs = self._collect_lora_pairs(lora_weights)
        targets = [
            (self.get_parameter_by_name(self.model, name), spec)
            for name, spec in lora_pairs.items()
        ]
        applied_count = 0
        # Autograd rejects in-place updates of leaf tensors that require grad,
        # so gradient tracking is disabled while merging.
        with torch.no_grad():
            for param, spec in tqdm(targets, desc="Loading LoRA weights"):
                param.add_(self._compute_delta(lora_weights, spec, alpha, device, param))
                applied_count += 1
        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.warning(
                "No LoRA weights were applied. Expected keys of the form "
                "'diffusion_model.<layer_name>.' followed by 'lora_down.weight'/'lora_up.weight', "
                "'lora_A.weight'/'lora_B.weight', 'diff' or 'diff_b'. Please verify the LoRA weight file."
            )

    def list_loaded_loras(self):
        """Return the names of all registered LoRAs, in registration order."""
        return list(self.lora_metadata.keys())

    def get_current_lora(self):
        """Return ``self.model.current_lora``, or None if the model has no such attribute.

        Nothing in this wrapper sets that attribute, so a missing one is not an error.
        """
        return getattr(self.model, "current_lora", None)