"""Helpers for inspecting PyTorch parameter / state-dict key names.

Provides a grouped pretty-printer for keys, a filter that keeps one
representative key per numbered family (``blocks.0``, ``blocks.1``, ...), and a
report of which parameters of a model currently hold a non-zero gradient.
"""
import re
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Set

try:
    from natsort import natsorted
except ImportError:
    # natsort is only used to order the final result; keep the module importable
    # without it by falling back to a small natural sort.
    def natsorted(seq):
        """Sort ``seq`` so that runs of digits compare as integers."""
        def _natural_key(text):
            # re.split with a capturing group alternates text / digit tokens,
            # so odd positions are always the digit runs.
            tokens = re.split(r'(\d+)', str(text))
            return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]
        return sorted(seq, key=_natural_key)


# A dot-delimited, purely numeric segment, e.g. the "0" in "blocks.0.weight".
# The lookahead leaves the trailing dot unconsumed so that adjacent numeric
# segments ("blocks.3.1.conv") are each matched.
_NUM_SEGMENT = re.compile(r'\.(\d+)(?=\.|$)')
# A dot-delimited letters+digits segment, e.g. "attn1" -> ("attn", "1").
_ALPHA_NUM_SEGMENT = re.compile(r'\.([a-zA-Z]+)(\d+)(?=\.|$)')


def pretty_print_torch_module_keys(
    keys: list,
    indent: int = 4,
    max_part_num: int = 2,
    max_examples: int = 1,
    show_counts: bool = True,
) -> None:
    """
    Pretty print PyTorch module keys with hierarchical grouping.

    Keys are grouped by their first ``max_part_num`` dot-separated parts. Each
    group prints its prefix, then up to ``max_examples`` members with the prefix
    stripped, then a "... (and N more)" line if members were omitted.

    Args:
        keys: List of parameter/buffer keys from state_dict
        indent: Number of spaces for the prefix line (examples use twice that)
        max_part_num: Maximum number of dot-separated parts kept in the group
            prefix (<= 0 means no truncation: every key is its own group)
        max_examples: Maximum example keys to show per group
        show_counts: Accepted for backward compatibility; it currently has no
            effect on the printed output.
    """
    groups = defaultdict(list)
    for key in keys:
        parts = key.split('.')
        # Only truncate when truncation is enabled and the key is long enough.
        if 0 < max_part_num < len(parts):
            prefix = '.'.join(parts[:max_part_num])
        else:
            prefix = key
        groups[prefix].append(key)

    outer_pad = ' ' * indent
    inner_pad = ' ' * (indent * 2)
    for prefix, members in sorted(groups.items()):
        print(f"{outer_pad}{prefix}")
        # Show example keys as the remainder after the group prefix.
        for ex in members[:max_examples]:
            print(f"{inner_pad}{ex[len(prefix):]}")
        if len(members) > max_examples:
            print(f"{inner_pad}... (and {len(members) - max_examples} more)")


def _index_signature(key: str, treat_alpha_digit: bool) -> tuple:
    """Return the integer indices found in ``key`` (numeric segments first)."""
    indices = [int(num) for num in _NUM_SEGMENT.findall(key)]
    if treat_alpha_digit:
        indices.extend(int(num) for _, num in _ALPHA_NUM_SEGMENT.findall(key))
    return tuple(indices)


def _pick_representative(group: List[str], keep_index: int, treat_alpha_digit: bool) -> str:
    """Choose the key of ``group`` that best matches ``keep_index``."""
    ordered = sorted(group, key=lambda k: _index_signature(k, treat_alpha_digit))

    if treat_alpha_digit:
        # First preference: a letter+digit segment carrying the wanted index.
        for key in ordered:
            if any(int(num) == keep_index for _, num in _ALPHA_NUM_SEGMENT.findall(key)):
                return key

    # Otherwise match on the first (primary) pure numeric index.
    for key in ordered:
        numbers = _NUM_SEGMENT.findall(key)
        if numbers and int(numbers[0]) == keep_index:
            return key

    # Wanted index not present: fall back to the lowest-indexed key.
    return ordered[0]


def get_representative_moduleNames(
    all_keys: List[str],
    ignore_prefixes: tuple = tuple(),
    keep_index: int = 0,
    treat_alpha_digit: bool = True,
) -> List[str]:
    """
    Filter state dict keys to keep only representative items (specific index
    in any numbered sequence).

    Keys that are identical once their numeric segments are replaced by a
    placeholder form one group, and one key per group is kept.

    Args:
        all_keys: All keys from state_dict (all are leaf nodes), e.g.
            ['learnable_vector', 'time_embed.0.weight', 'time_embed.1.weight']
            -> ['learnable_vector', 'time_embed.0.weight']
        ignore_prefixes: Keys starting with any of these prefixes are dropped
            before grouping
        keep_index: Which index to keep when multiple numbered items exist
            (default 0 for first). If no key of a group carries this index,
            the lowest-indexed key of the group is kept instead.
        treat_alpha_digit: If True, also treat letter+digit combinations
            (e.g., 'attn1', 'attn2') as numbered sequences

    Returns:
        Naturally sorted list of the kept keys.
    """
    if ignore_prefixes:
        skip = tuple(ignore_prefixes)
        all_keys = [k for k in all_keys if not k.startswith(skip)]

    # Group keys by their shape with every index replaced by 'X'.
    groups = defaultdict(list)
    for key in all_keys:
        pattern = _NUM_SEGMENT.sub('.X', key)
        if treat_alpha_digit:
            # 'attn1' -> 'attnX'
            pattern = _ALPHA_NUM_SEGMENT.sub(r'.\1X', pattern)
        groups[pattern].append(key)

    filtered_keys = []
    for members in groups.values():
        if len(members) == 1:
            filtered_keys.append(members[0])
        else:
            filtered_keys.append(_pick_representative(members, keep_index, treat_alpha_digit))

    return natsorted(filtered_keys)


def get_no_grad_and_has_grad_keys(
    model,
    only_representative: bool = True,
    ignore_prefixes: tuple = tuple(),
    verbose: int = 1,
    get_representative_moduleNames_at_first: bool = False,
    save_path: Optional[str] = None,
):
    """
    Split the named parameters of ``model`` by whether they hold a gradient.

    A parameter counts as "has grad" when it requires grad and its ``.grad`` is
    present with at least one non-zero entry; everything else is "no grad".

    Args:
        model: Object exposing ``named_parameters()``. ``state_dict()`` is not
            used because it lacks gradient information.
        only_representative: If True, reduce the key lists with
            ``get_representative_moduleNames``
        ignore_prefixes: Forwarded to ``get_representative_moduleNames``
        verbose: Console verbosity. Messages are printed when ``verbose`` is at
            least their level: 1 = final total, 2 = the two key lists,
            3 = per-parameter reasons for "no grad", 4 = per-parameter
            "has grad" lines. The saved log always contains every message.
        get_representative_moduleNames_at_first: If True, reduce the keys
            before inspecting gradients instead of reducing the two result
            lists afterwards
        save_path: If not None, write the full log to this file

    Returns:
        Tuple ``(k_has_grad, k_no_grad)`` of key lists.
    """
    all_params = dict(model.named_parameters())
    keys = list(all_params)

    # For file logging, collect all messages regardless of verbosity.
    log_messages = []

    def print_(*msg, verb=1):
        if verbose >= verb:
            print(*msg)
        if save_path is not None:
            log_messages.append(' '.join(str(part) for part in msg))

    if only_representative and get_representative_moduleNames_at_first:
        keys = get_representative_moduleNames(keys, ignore_prefixes=ignore_prefixes)

    k_has_grad = []
    k_no_grad = []  # does not require grad, or .grad is missing / all zero
    for name in keys:
        param = all_params[name]
        if not param.requires_grad:
            k_no_grad.append(name)
        elif param.grad is None:
            print_(f"{name} has grad but grad is None", verb=3)
            k_no_grad.append(name)
        elif param.grad.abs().sum() == 0:
            # abs() first: a plain sum() is also zero when entries cancel out.
            print_(f"{name} has grad but grad is 0", verb=3)
            k_no_grad.append(name)
        else:
            print_(f"{name} has grad !=0", verb=4)
            k_has_grad.append(name)

    if only_representative and not get_representative_moduleNames_at_first:
        k_no_grad = get_representative_moduleNames(k_no_grad, ignore_prefixes=ignore_prefixes)
        k_has_grad = get_representative_moduleNames(k_has_grad, ignore_prefixes=ignore_prefixes)

    print_("No grad:", verb=2)
    for name in k_no_grad:
        print_(f"  - {name}", verb=2)
    print_("Has grad:", verb=2)
    for name in k_has_grad:
        print_(f"  - {name}", verb=2)
    print_(f"Total: {len(k_no_grad) + len(k_has_grad)} {len(k_has_grad)=}", verb=1)

    if save_path is not None:
        Path(save_path).write_text('\n'.join(log_messages), encoding='utf-8')  # !diskW
        print(f"> {save_path}")
    return k_has_grad, k_no_grad


def _demo() -> None:
    """Print the representative keys of a checkpoint, or of built-in examples."""
    example_keys = [
        'face_ID_model.facenet.input_layer.0.weight',
        'face_ID_model.facenet.input_layer.1.weight',
        'face_ID_model.facenet.input_layer.1.bias',
        'face_ID_model.facenet.input_layer.1.running_mean',
        'face_ID_model.facenet.input_layer.1.running_var',
        'face_ID_model.facenet.input_layer.1.num_batches_tracked',
        'face_ID_model.facenet.input_layer.2.weight',
        'learnable_vector',
        'model.diffusion_model_refNet.time_embed.0.weight',
        'model.diffusion_model_refNet.time_embed.0.weight.xxx',
        'model.diffusion_model_refNet.time_embed.0.bias',
        'model.diffusion_model_refNet.time_embed.0.xxxx.0',
        'model.diffusion_model_refNet.time_embed.0.xxxx.1',
        'model.diffusion_model_refNet.time_embed.0.xxxx.2',
        'model.diffusion_model_refNet.time_embed.1.weight',
        'model.diffusion_model_refNet.time_embed.1.bias',
        'model.diffusion_model_refNet.time_embed.0.submodule.param',
        'model.diffusion_model_refNet.time_embed.1.submodule.param',
        'model.diffusion_model_refNet.input_blocks.0.weight',
        'model.diffusion_model_refNet.input_blocks.1.weight',
        'model.diffusion_model_refNet.middle_block.0.weight',
        'model.diffusion_model_refNet.output_blocks.0.bias',
        'model.diffusion_model_refNet.output_blocks.1.bias',
        'model.diffusion_model_refNet.output_blocks.2.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn3.xxxxx',
    ]

    ckpt_path = Path('checkpoints/pretrained.ckpt')
    if ckpt_path.exists():
        import torch
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(str(ckpt_path))
        all_keys = list(sd['state_dict'].keys())
    else:
        # No checkpoint available: demonstrate on the built-in example keys.
        all_keys = example_keys

    filtered = get_representative_moduleNames(all_keys)
    print("Filtered representative keys (keep_index=0, default):")
    # The result is already naturally sorted; keep that order.
    for key in filtered:
        print(f"  - {key}")


if __name__ == '__main__':
    _demo()