File size: 2,423 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import torch.nn as nn

from . import logging


logger = logging.get_logger(__name__)


def pretty_print_trainable_parameters(model: nn.Module) -> None:
    """Log a compact summary of the parameters of ``model`` that require grad.

    Parameter names are split on ``"."`` and every purely numeric component
    is replaced by a positional placeholder, so names that differ only in
    those numbers (e.g. ``layers.0.weight`` and ``layers.1.weight``) collapse
    into a single entry. Each placeholder is then rendered as ``[i]`` when
    all collapsed names share the same number at that position, or as
    ``[min-max]`` otherwise.

    Note that ``[min-max]`` is only the bounding range of the numbers seen at
    that position; it does not show whether every index in between is
    trainable.

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        None. The report is emitted through ``logger.info_rank0``.
    """
    trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]

    # Maps a name pattern (numbers replaced by "<k>") to the list of number
    # sequences, one per parameter that collapsed into that pattern.
    indices_by_pattern = {}
    for name in trainable_names:
        pattern_parts = []
        indices = []
        for part in name.split("."):
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscripts that int() cannot parse.
            if part.isdecimal():
                pattern_parts.append(f"<{len(indices)}>")
                indices.append(int(part))
            else:
                pattern_parts.append(part)
        indices_by_pattern.setdefault(".".join(pattern_parts), []).append(indices)

    banner = "\n**** trainable parameters ****"
    report = [banner]
    for pattern, index_rows in indices_by_pattern.items():
        index_array = np.array(index_rows)
        if index_array.shape[-1] == 0:
            # No numeric component in the name: print it verbatim.
            report.append("\n" + pattern)
            continue

        # Column-wise bounds: one (min, max) pair per placeholder position.
        lows = index_array.min(axis=0)
        highs = index_array.max(axis=0)
        rendered = pattern
        for slot, (low, high) in enumerate(zip(lows, highs)):
            label = f"[{low}]" if low == high else f"[{low}-{high}]"
            rendered = rendered.replace(f"<{slot}>", label)
        report.append("\n" + rendered)
    report.append(banner)
    logger.info_rank0("".join(report))