# convert_lora_i2v_to_fc.py
import torch
import safetensors.torch
import safetensors  # Needed for safetensors.safe_open
import argparse
import os
import re  # Regular expressions could help with more complex key parsing if needed

# !!! IMPORTANT: Update based on the output of analyze_wan_models.py !!!
# The base layer name identified with a shape mismatch.
# Check your LoRA file's keys if they use a different prefix (e.g., 'transformer.').
# The base name identified in the LoRA keys is assumed to match this.
BASE_LAYERS_TO_SKIP_LORA = {
    "patch_embedding",  # The layer name from the analysis output
    # Add other layers here ONLY if the analysis revealed more mismatches
}
# !!! END IMPORTANT SECTION !!!
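
# A minimal sketch (not called anywhere in this script) of how the skip list above could
# be derived programmatically instead of copied from analyze_wan_models.py. It assumes
# both base models are available locally as single .safetensors checkpoints; the paths
# passed in are placeholders you would supply yourself.
def find_shape_mismatches(model_a_path: str, model_b_path: str) -> set:
    """Return base layer names whose tensor shapes differ between two checkpoints."""
    mismatched = set()
    with safetensors.safe_open(model_a_path, framework="pt", device="cpu") as fa, \
         safetensors.safe_open(model_b_path, framework="pt", device="cpu") as fb:
        for key in set(fa.keys()) & set(fb.keys()):
            # get_slice() reads only the tensor header, so shapes are compared
            # without loading the full tensors into memory.
            if fa.get_slice(key).get_shape() != fb.get_slice(key).get_shape():
                # Strip a trailing ".weight"/".bias" so the result matches the
                # base-layer naming used in BASE_LAYERS_TO_SKIP_LORA.
                mismatched.add(key.rsplit(".", 1)[0])
    return mismatched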

def get_base_layer_name(lora_key: str, prefixes=["lora_transformer_", "lora_unet_"]):
    """
    Attempts to extract the base model layer name from a LoRA key.
    Handles common prefixes and suffixes. Adjust the prefixes if needed.

    Example: "lora_transformer_patch_embedding_down.weight" -> "patch_embedding"
             "lora_transformer_blocks_0_attn_qkv.alpha"     -> "blocks.0.attn.qkv"

    Args:
        lora_key (str): The key from the LoRA state dictionary.
        prefixes (list[str]): A list of potential prefixes used in LoRA keys.

    Returns:
        str: The inferred base model layer name.
    """
    cleaned_key = lora_key

    # Remove known prefixes
    for prefix in prefixes:
        if cleaned_key.startswith(prefix):
            cleaned_key = cleaned_key[len(prefix):]
            break  # Assume only one prefix matches

    # Remove known suffixes.
    # Order matters slightly if one suffix is part of another; list longer ones first if needed.
    known_suffixes = [
        ".lora_up.weight",
        ".lora_down.weight",
        "_lora_up.weight",   # Include underscore variants just in case
        "_lora_down.weight",
        ".alpha",
    ]
    for suffix in known_suffixes:
        if cleaned_key.endswith(suffix):
            cleaned_key = cleaned_key[:-len(suffix)]
            break

    # Replace underscores used by some training scripts with periods for consistency
    # with base models that use periods (like typical PyTorch modules).
    # Adjust this logic if the base model itself uses underscores extensively.
    cleaned_key = cleaned_key.replace("_", ".")

    # Specific fix for the target layer if prefix/suffix removal was incomplete or ambiguous.
    # This is somewhat heuristic and may need adjustment based on the exact LoRA key naming.
    if cleaned_key.startswith("patch.embedding"):
        # Map potential variants back to the canonical name found in the analysis
        cleaned_key = "patch_embedding"
    # Add elif clauses here if other specific key mappings are needed

    return cleaned_key
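
# Hypothetical illustration of the mapping above (the key below is an example only,
# assuming kohya-style "lora_transformer_" naming; other trainers may need extra
# entries in `prefixes`):
#   get_base_layer_name("lora_transformer_blocks_0_cross_attn_k.lora_down.weight")
#       -> "blocks.0.cross.attn.k"
# Note that underscores *inside* module names (e.g. "cross_attn") also become periods,
# which is fine for the skip-list comparison below but matters if you reuse this helper
# to address modules in the base model directly.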

def convert_lora(source_lora_path: str, target_lora_path: str):
    """
    Converts an i2v_14B LoRA to be compatible with i2v_14B_FC by
    removing LoRA weights associated with layers that have incompatible shapes.

    Args:
        source_lora_path (str): Path to the input LoRA file (.safetensors).
        target_lora_path (str): Path to save the converted LoRA file (.safetensors).
    """
    print(f"Loading source LoRA from: {source_lora_path}")
    if not os.path.exists(source_lora_path):
        print(f"Error: Source file not found: {source_lora_path}")
        return

    try:
        # Load tensors and metadata using safe_open for better handling
        source_lora_state_dict = {}
        metadata = {}
        with safetensors.safe_open(source_lora_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()  # Get metadata if it exists
            if metadata is None:  # Ensure metadata is a dict even if empty
                metadata = {}
            for key in f.keys():
                source_lora_state_dict[key] = f.get_tensor(key)  # Load tensors
        print(f"Successfully loaded {len(source_lora_state_dict)} tensors.")
        if metadata:
            print(f"Found metadata: {metadata}")
        else:
            print("No metadata found.")
    except Exception as e:
        print(f"Error loading LoRA file: {e}")
        import traceback
        traceback.print_exc()
        return

    target_lora_state_dict = {}
    skipped_keys = []
    kept_keys = []
    base_name_map = {}  # Store the key -> base layer name mapping for reporting

    print("\nConverting LoRA weights...")
    print(f"Will skip LoRA weights targeting these base layers: {BASE_LAYERS_TO_SKIP_LORA}")

    # Iterate through the loaded tensors
    for key, tensor in source_lora_state_dict.items():
        # Use the helper function to extract the base layer name
        base_layer_name = get_base_layer_name(key)
        base_name_map[key] = base_layer_name  # Store for reporting purposes

        # Check whether the identified base layer should be skipped
        if base_layer_name in BASE_LAYERS_TO_SKIP_LORA:
            skipped_keys.append(key)
        else:
            # Keep the tensor if its base layer is not in the skip list
            target_lora_state_dict[key] = tensor
            kept_keys.append(key)

    # --- Reporting ---
    print("\nConversion Summary:")
    print(f"  - Total tensors in source: {len(source_lora_state_dict)}")
    print(f"  - Kept {len(kept_keys)} LoRA weight tensors.")
    print(f"  - Skipped {len(skipped_keys)} LoRA weight tensors (incompatible base layer shape):")
    if skipped_keys:
        max_print = 15  # Maximum number of skipped keys to list
        skipped_sorted = sorted(skipped_keys)  # Sort for consistent output order
        for i, key in enumerate(skipped_sorted):
            base_name = base_name_map.get(key, "N/A")  # Get the identified base name
            print(f"    - {key} (base layer identified: {base_name})")
            if i >= max_print - 1 and len(skipped_keys) > max_print:
                print(f"    ... and {len(skipped_keys) - max_print} more.")
                break
    else:
        print("    None")

    # --- Saving ---
    print(f"\nSaving converted LoRA ({len(target_lora_state_dict)} tensors) to: {target_lora_path}")
    try:
        # Save the filtered state dictionary with the original metadata
        safetensors.torch.save_file(target_lora_state_dict, target_lora_path, metadata=metadata)
        print("Conversion successful!")
    except Exception as e:
        print(f"Error saving converted LoRA file: {e}")

if __name__ == "__main__":
    # Set up the argument parser
    parser = argparse.ArgumentParser(
        description="Convert a Wan i2v_14B LoRA to an i2v_14B_FC LoRA by removing incompatible patch_embedding weights.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("source_lora", type=str, help="Path to the source i2v_14B LoRA file (.safetensors).")
    parser.add_argument("target_lora", type=str, help="Path to save the converted i2v_14B_FC LoRA file (.safetensors).")
    args = parser.parse_args()

    # --- Input Validation ---
    # Hard errors abort the run; warnings are reported but do not block the conversion.
    if not os.path.exists(args.source_lora):
        print(f"Error: Source LoRA file not found at '{args.source_lora}'")
    elif args.source_lora == args.target_lora:
        print(f"Error: Source and target paths cannot be the same ('{args.source_lora}'). Choose a different target path.")
    else:
        if not args.source_lora.lower().endswith(".safetensors"):
            print(f"Warning: Source file '{args.source_lora}' does not have a .safetensors extension.")
        if os.path.exists(args.target_lora):
            print(f"Warning: Target file '{args.target_lora}' already exists and will be overwritten.")
            # Optionally add a --force flag or prompt the user here
        convert_lora(args.source_lora, args.target_lora)
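
# Example usage (file names below are placeholders; adjust to your own paths):
#   python convert_lora_i2v_to_fc.py my_i2v_lora.safetensors my_i2v_fc_lora.safetensors
#
# A quick sanity check of the result, assuming the LoRA keys spell the layer with
# underscores (run separately, e.g. in a Python shell):
#   from safetensors import safe_open
#   with safe_open("my_i2v_fc_lora.safetensors", framework="pt", device="cpu") as f:
#       assert not any("patch_embedding" in k for k in f.keys())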