|
|
| from typing import List, Set |
| from natsort import natsorted |
| from pathlib import Path |
|
|
def pretty_print_torch_module_keys(
    keys: list,
    indent: int = 4,
    max_part_num: int = 2,
    max_examples: int = 1,
    show_counts: bool = True
) -> None:
    """
    Pretty print PyTorch module keys with hierarchical grouping.

    Args:
        keys: List of parameter/buffer keys from state_dict
        indent: Number of spaces for indentation
        max_part_num: Maximum number of dot-separated parts to use as the
            group prefix (0 or less = no grouping; each key is its own group)
        max_examples: Maximum example keys to show per group
        show_counts: Whether to show count of keys in each group
    """
    from collections import defaultdict

    # Group keys by their dotted prefix (first `max_part_num` parts).
    groups = defaultdict(list)
    for key in keys:
        if max_part_num <= 0:
            # No truncation: every key forms its own group.
            groups[key].append(key)
        else:
            parts = key.split('.')
            prefix = '.'.join(parts[:max_part_num]) if len(parts) > max_part_num else key
            groups[prefix].append(key)

    for prefix, members in sorted(groups.items()):
        # BUG FIX: count_str was computed but never included in the output,
        # so the `show_counts` flag previously had no effect.
        count_str = f" ({len(members)} keys)" if show_counts else ""
        print(f"{' ' * indent}{prefix}{count_str}")

        # Show up to `max_examples` member suffixes (key minus the prefix).
        for ex in members[:max_examples]:
            print(f"{' ' * (indent * 2)}{ex[len(prefix):]}")

        if len(members) > max_examples:
            print(f"{' ' * (indent * 2)}... (and {len(members) - max_examples} more)")
|
|
|
|
|
|
|
|
def get_representative_moduleNames(
    all_keys: List[str],
    ignore_prefixes: tuple = tuple(),
    keep_index: int = 0, treat_alpha_digit: bool = True) -> List[str]:
    """
    Filter state dict keys to keep only representative items (a specific index
    in any numbered sequence).

    Keys are grouped by the "pattern" obtained by replacing every numeric path
    segment (and optionally letter+digit segments like 'attn1') with a
    placeholder 'X'; one representative key is kept per group.

    Args:
        all_keys: List of all keys from state_dict (all are leaf nodes)
            eg. ['learnable_vector', 'model.diffusion_model.time_embed.0.weight',
            'model.diffusion_model.time_embed.0.bias', ...]
        ignore_prefixes: Keys starting with any of these prefixes are dropped.
        keep_index: Which index to keep when multiple numbered items exist (default 0 for first)
        treat_alpha_digit: If True, also treat letter+digit combinations (e.g., 'attn1', 'attn2') as numbered sequences
    Returns:
        Naturally-sorted list of filtered keys preserving only representative
        items. (BUG FIX: annotation said Set[str], but natsorted returns a list.)
    """
    import re
    from collections import defaultdict

    if ignore_prefixes:
        all_keys = [k for k in all_keys if not any(k.startswith(p) for p in ignore_prefixes)]

    groups = defaultdict(list)
    for key in all_keys:
        # Normalize purely numeric segments: '.12.' -> '.X.', trailing '.12' -> '.X'
        pattern = re.sub(r'\.(\d+)\.', '.X.', key)
        pattern = re.sub(r'\.(\d+)$', '.X', pattern)
        if treat_alpha_digit:
            # Normalize letter+digit segments: '.attn1.' -> '.attnX.'
            pattern = re.sub(r'\.([a-zA-Z]+)(\d+)\.', r'.\1X.', pattern)
            pattern = re.sub(r'\.([a-zA-Z]+)(\d+)$', r'.\1X', pattern)
        groups[pattern].append(key)

    def get_numeric_indices(key):
        # Purely numeric segment values in path order, then (optionally) the
        # digits of letter+digit segments. Used both as sort key and selector.
        matches = re.findall(r'\.(\d+)(?:\.|$)', key)
        indices = [int(x) for x in matches]
        if treat_alpha_digit:
            alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key)
            for _, digit in alpha_digit_matches:
                indices.append(int(digit))
        return tuple(indices)

    filtered_keys = []
    for pattern, keys_in_group in groups.items():
        if len(keys_in_group) == 1:
            # Unique pattern: nothing to choose between.
            filtered_keys.extend(keys_in_group)
            continue

        keys_in_group.sort(key=get_numeric_indices)

        representative = None
        if treat_alpha_digit:
            # Prefer keys whose letter+digit segment matches keep_index
            # (this matches the original selection priority).
            for key in keys_in_group:
                alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key)
                if any(int(digit) == keep_index for _, digit in alpha_digit_matches):
                    representative = key
                    break
        if representative is None:
            # BUG FIX: previously, when treat_alpha_digit was True, groups that
            # varied only in purely numeric segments ignored keep_index entirely
            # (the numeric check lived only in the non-alpha branch). For the
            # default keep_index=0 this is behavior-preserving: after sorting,
            # the first key with indices[0] == 0 is the first key overall.
            for key in keys_in_group:
                indices = get_numeric_indices(key)
                if indices and indices[0] == keep_index:
                    representative = key
                    break

        # Fallback: no key matched keep_index; keep the lowest-indexed one.
        filtered_keys.append(representative if representative is not None else keys_in_group[0])

    return natsorted(filtered_keys)
|
|
def get_no_grad_and_has_grad_keys(
        model, only_representative: bool = True,
        ignore_prefixes: tuple = tuple(),
        verbose: int = 1,
        get_representative_moduleNames_at_first: bool = False,
        save_path: str = None,
):
    """
    Split a model's parameter names by whether they currently carry a non-zero
    gradient (e.g. after a backward pass).

    Args:
        model: torch.nn.Module whose named_parameters() are inspected.
        only_representative: Collapse numbered sibling modules to one
            representative key via get_representative_moduleNames.
        ignore_prefixes: Prefixes to drop during representative filtering.
        verbose: Verbosity level; a message is printed when verbose >= its verb.
        get_representative_moduleNames_at_first: Filter to representatives
            BEFORE the gradient check instead of after.
        save_path: If given, every log message (regardless of verbosity) is
            written to this file.

    Returns:
        (k_has_grad, k_no_grad): lists of parameter names with / without a
        non-zero gradient. A param counts as "no grad" when it does not require
        grad, its .grad is None, or its .grad sums to zero.
    """
    all_params = dict(model.named_parameters())
    keys = list(all_params.keys())

    log_messages = []

    def print_(*msg, verb=1):
        if verbose >= verb:
            print(*msg)
        if save_path is not None:
            # Log everything, even below the verbosity threshold.
            # BUG FIX: extend(msg) put each arg on its own line and crashed the
            # later '\n'.join on non-str args; mirror print's joining instead.
            log_messages.append(' '.join(str(m) for m in msg))

    if only_representative and get_representative_moduleNames_at_first:
        keys = get_representative_moduleNames(keys, ignore_prefixes=ignore_prefixes)

    k_has_grad = []
    k_no_grad = []

    for name in keys:
        if name not in all_params:
            # Can only happen when representative filtering ran first.
            print_(f"{name} not found in named_parameters (might be buffer)", verb=3)
            k_no_grad.append(name)
            continue

        param = all_params[name]
        if not param.requires_grad:
            k_no_grad.append(name)
        elif param.grad is None:
            print_(f"{name} has grad but grad is None", verb=3)
            k_no_grad.append(name)
        elif param.grad.sum() == 0:
            print_(f"{name} has grad but grad is 0", verb=3)
            k_no_grad.append(name)
        else:
            print_(f"{name} has grad !=0", verb=4)
            k_has_grad.append(name)

    if only_representative and not get_representative_moduleNames_at_first:
        k_no_grad = get_representative_moduleNames(k_no_grad, ignore_prefixes=ignore_prefixes)
        k_has_grad = get_representative_moduleNames(k_has_grad, ignore_prefixes=ignore_prefixes)

    print_("No grad:", verb=2)
    for name in k_no_grad:
        print_(f" - {name}", verb=2)
    print_("Has grad:", verb=2)
    # (removed dead `if 0:` branch that unconditionally skipped to the loop)
    for name in k_has_grad:
        print_(f" - {name}", verb=2)
    print_(f"Total: {len(k_no_grad) + len(k_has_grad)} {len(k_has_grad)=}", verb=1)

    if save_path is not None:
        Path(save_path).write_text('\n'.join(log_messages), encoding='utf-8')
        print(f"> {save_path}")

    return k_has_grad, k_no_grad
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo keys covering the patterns the filter should collapse:
    # numbered layers, letter+digit siblings ('attn1'/'attn2'), plain leaves.
    all_keys = [
        'face_ID_model.facenet.input_layer.0.weight',
        'face_ID_model.facenet.input_layer.1.weight',
        'face_ID_model.facenet.input_layer.1.bias',
        'face_ID_model.facenet.input_layer.1.running_mean',
        'face_ID_model.facenet.input_layer.1.running_var',
        'face_ID_model.facenet.input_layer.1.num_batches_tracked',
        'face_ID_model.facenet.input_layer.2.weight',

        'learnable_vector',
        'model.diffusion_model_refNet.time_embed.0.weight',
        'model.diffusion_model_refNet.time_embed.0.weight.xxx',
        'model.diffusion_model_refNet.time_embed.0.bias',
        'model.diffusion_model_refNet.time_embed.0.xxxx.0',
        'model.diffusion_model_refNet.time_embed.0.xxxx.1',
        'model.diffusion_model_refNet.time_embed.0.xxxx.2',

        'model.diffusion_model_refNet.time_embed.1.weight',
        'model.diffusion_model_refNet.time_embed.1.bias',
        'model.diffusion_model_refNet.time_embed.0.submodule.param',
        'model.diffusion_model_refNet.time_embed.1.submodule.param',
        'model.diffusion_model_refNet.input_blocks.0.weight',
        'model.diffusion_model_refNet.input_blocks.1.weight',
        'model.diffusion_model_refNet.middle_block.0.weight',
        'model.diffusion_model_refNet.output_blocks.0.bias',
        'model.diffusion_model_refNet.output_blocks.1.bias',
        'model.diffusion_model_refNet.output_blocks.2.bias',

        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn3.xxxxx',
    ]

    # BUG FIX: the checkpoint was previously loaded unconditionally, which
    # silently discarded the demo list above and crashed when the file was
    # missing. Only use the checkpoint's keys when it actually exists.
    ckpt_path = Path('checkpoints/pretrained.ckpt')
    if ckpt_path.exists():
        import torch
        sd = torch.load(ckpt_path)
        all_keys = sd['state_dict'].keys()

    filtered = get_representative_moduleNames(all_keys)
    print("Filtered representative keys (keep_index=0, default):")
    for key in sorted(filtered):
        print(f" - {key}")