File size: 11,072 Bytes
2b534de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264

from typing import List, Set
from natsort import natsorted
from pathlib import Path

def pretty_print_torch_module_keys(
    keys: list,
    indent: int = 4,
    # max_part_num: int = 3,
    # max_examples: int = 2,
    max_part_num: int = 2,
    max_examples: int = 1,
    show_counts: bool = True
) -> None:
    """
    Pretty print PyTorch module keys with hierarchical grouping.
    
    Args:
        keys: List of parameter/buffer keys from state_dict
        max_part_num: Maximum number of dot-separated parts to show (0=no truncation)
        indent: Number of spaces for indentation
        max_examples: Maximum example keys to show per group
        show_counts: Whether to show count of keys in each group
    """
    # Group keys by their truncated prefix
    from collections import defaultdict
    groups = defaultdict(list)
    for key in keys:
        if max_part_num <= 0:  # No truncation
            groups[key].append(key)
        else:
            # Split into parts and rejoin the first N parts
            parts = key.split('.')
            prefix = '.'.join(parts[:max_part_num]) if len(parts) > max_part_num else key
            groups[prefix].append(key)

    for prefix, members in sorted(groups.items()):
        _s = f"{' ' * indent}{prefix}"
        count_str = f" ({len(members)} keys)" if show_counts else ""
        # _s += f"{count_str}:"
        print(_s)
        
        # Show example keys (full paths)
        examples = members[:max_examples]
        for ex in examples:
            # print(f"{' ' * (indent * 2)}- {ex[len(prefix):]}")
            print(f"{' ' * (indent * 2)}{ex[len(prefix):]}")
        
        if len(members) > max_examples:
            print(f"{' ' * (indent * 2)}... (and {len(members) - max_examples} more)")




def get_representative_moduleNames(
    all_keys: List[str], 
    ignore_prefixes: tuple = tuple(),
    keep_index: int = 0, treat_alpha_digit: bool = True) -> Set[str]:
    """
    Filter state dict keys to keep only representative items (specific index in any numbered sequence).
    Args:
        all_keys: List of all keys from state_dict (all are leaf nodes)
            eg. ['learnable_vector', 'model.diffusion_model.time_embed.0.weight', 'model.diffusion_model.time_embed.0.bias',
        keep_index: Which index to keep when multiple numbered items exist (default 0 for first)
        treat_alpha_digit: If True, also treat letter+digit combinations (e.g., 'attn1', 'attn2') as numbered sequences
    Returns:
        Set of filtered keys preserving only representative items
    """
    import re
    if ignore_prefixes:
        all_keys = [k for k in all_keys if not any(k.startswith(p) for p in ignore_prefixes)]
    num_pattern = re.compile(r'\.(\d+)\.') # Pattern to match numbers in paths (e.g., '.0.', '.1.', etc.)
    # Group keys by their pattern (replace numbers with X for grouping)
    from collections import defaultdict
    groups = defaultdict(list)
    
    for key in all_keys:
        # Create a pattern by replacing all numbers with 'X'
        pattern = re.sub(r'\.(\d+)\.', '.X.', key)
        # Also handle numbers at the end of the key
        pattern = re.sub(r'\.(\d+)$', '.X', pattern)
        
        if treat_alpha_digit:
            # Also replace letter+digit combinations (e.g., 'attn1' -> 'attnX')
            pattern = re.sub(r'\.([a-zA-Z]+)(\d+)\.', r'.\1X.', pattern)
            pattern = re.sub(r'\.([a-zA-Z]+)(\d+)$', r'.\1X', pattern)
        
        groups[pattern].append(key)
    # print(f"Debug groups: {groups}")
    
    filtered_keys = []
    for pattern, keys_in_group in groups.items():
        if len(keys_in_group) == 1:
            # Only one key in this pattern group - keep it
            filtered_keys.extend(keys_in_group)
        else:
            # Multiple keys - find the one with the target index
            def get_numeric_indices(key):
                # Extract all numeric indices from the key (pure numbers)
                matches = re.findall(r'\.(\d+)(?:\.|$)', key)
                indices = [int(x) for x in matches]
                
                if treat_alpha_digit:
                    # Also extract indices from letter+digit combinations
                    alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key)
                    for _, digit in alpha_digit_matches:
                        indices.append(int(digit))
                
                return tuple(indices)
            
            # Sort by numeric indices 
            keys_in_group.sort(key=get_numeric_indices)
            
            # Try to find the key with the desired index
            target_found = False
            for key in keys_in_group:
                if treat_alpha_digit:
                    # For alpha+digit mode, check if any alpha+digit combination has the target index
                    alpha_digit_matches = re.findall(r'\.([a-zA-Z]+)(\d+)(?:\.|$)', key)
                    for prefix, digit in alpha_digit_matches:
                        if int(digit) == keep_index:
                            filtered_keys.append(key)
                            target_found = True
                            break
                    if target_found:
                        break
                else:
                    # For normal mode, check pure numeric indices
                    indices = get_numeric_indices(key)
                    # Check if the first (primary) index matches keep_index  
                    if indices and indices[0] == keep_index:
                        filtered_keys.append(key)
                        target_found = True
                        break
            
            # If target index not found, fall back to the first available
            if not target_found:
                filtered_keys.append(keys_in_group[0])
    
    filtered_keys = natsorted(filtered_keys)
    return filtered_keys

def get_no_grad_and_has_grad_keys(
    model, only_representative: bool = True, 
    ignore_prefixes: tuple = tuple(),
    verbose: int = 1,  # for print (not for file save. for save, we log all ) 0,1: only print at last, 2: print at each step
    get_representative_moduleNames_at_first :bool = False,
    save_path: str = None,  # if not None, save detailed log to file
):
    # don't use state_dict() (it lacks gradient information)
    all_params = dict(model.named_parameters())
    keys = list(all_params.keys())
    
    # For file logging, collect all messages
    log_messages = []
    
    def print_(*msg, verb=1):
        if verbose >= verb:
            print(*msg)
        if save_path is not None:
            log_messages.extend(msg)
    
    if only_representative and get_representative_moduleNames_at_first:
        keys = get_representative_moduleNames(keys, ignore_prefixes=ignore_prefixes)
    
    k_has_grad = []
    k_no_grad = []  # dont require grad or .grad is 0
    
    for name in keys:
        if name not in all_params:
            print_(f"{name} not found in named_parameters (might be buffer)", verb=3)
            k_no_grad.append(name)
            continue
            
        param = all_params[name]
        if param.requires_grad:
            if param.grad is None:
                print_(f"{name} has grad but grad is None", verb=3)
                k_no_grad.append(name)
            elif param.grad.sum() == 0:
                print_(f"{name} has grad but grad is 0", verb=3)
                k_no_grad.append(name)
            else:
                print_(f"{name} has grad !=0", verb=4)
                k_has_grad.append(name)
        else:
            k_no_grad.append(name)
    if only_representative and not get_representative_moduleNames_at_first:
        k_no_grad  = get_representative_moduleNames(k_no_grad,  ignore_prefixes=ignore_prefixes)
        k_has_grad = get_representative_moduleNames(k_has_grad, ignore_prefixes=ignore_prefixes)
            
    print_("No grad:", verb=2)
    for name in k_no_grad:
        print_(f"  - {name}", verb=2)
    print_("Has grad:", verb=2)
    if 0:
        print_("<skip.>", verb=2)
    else:
        for name in k_has_grad:
            print_(f"  - {name}", verb=2)
    print_(f"Total: {len(k_no_grad) + len(k_has_grad)} {len(k_has_grad)=}", verb=1)
    
    if save_path is not None:
        Path(save_path).write_text('\n'.join(log_messages), encoding='utf-8')  # !diskW
        print(f"> {save_path}")
    
    return k_has_grad, k_no_grad






if __name__=='__main__':
    # Example usage:
    all_keys = [
        'face_ID_model.facenet.input_layer.0.weight',
        'face_ID_model.facenet.input_layer.1.weight',
        'face_ID_model.facenet.input_layer.1.bias',
        'face_ID_model.facenet.input_layer.1.running_mean',
        'face_ID_model.facenet.input_layer.1.running_var',
        'face_ID_model.facenet.input_layer.1.num_batches_tracked',
        'face_ID_model.facenet.input_layer.2.weight',
        
        'learnable_vector',
        'model.diffusion_model_refNet.time_embed.0.weight',
        'model.diffusion_model_refNet.time_embed.0.weight.xxx',
        'model.diffusion_model_refNet.time_embed.0.bias',
        'model.diffusion_model_refNet.time_embed.0.xxxx.0',
        'model.diffusion_model_refNet.time_embed.0.xxxx.1',
        'model.diffusion_model_refNet.time_embed.0.xxxx.2',
        
        'model.diffusion_model_refNet.time_embed.1.weight',
        'model.diffusion_model_refNet.time_embed.1.bias',
        'model.diffusion_model_refNet.time_embed.0.submodule.param',
        'model.diffusion_model_refNet.time_embed.1.submodule.param',
        'model.diffusion_model_refNet.input_blocks.0.weight',
        'model.diffusion_model_refNet.input_blocks.1.weight',
        'model.diffusion_model_refNet.middle_block.0.weight',
        'model.diffusion_model_refNet.output_blocks.0.bias',
        'model.diffusion_model_refNet.output_blocks.1.bias',
        'model.diffusion_model_refNet.output_blocks.2.bias',

        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight',
        'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn3.xxxxx',
    ]
    

    import torch
    sd = torch.load('checkpoints/pretrained.ckpt')
    all_keys = sd['state_dict'].keys()
    filtered = get_representative_moduleNames(all_keys)
    print(f"Filtered representative keys (keep_index=0, default):")
    for key in sorted(filtered):
        print(f"  - {key}")