File size: 5,489 Bytes
7f3dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import numpy as np
import os
import torch.nn as nn

from einops import rearrange, reduce, repeat

from .knowledge_encoder import Knowledge_Encoder
from .med_cpt import MedCPT
from .base_bert import BaseBERT
from train.dist import is_master

def compute_average_gradient(module):
    """Return the mean of the per-parameter mean absolute gradients.

    For every parameter of ``module`` that currently holds a gradient, the
    mean of ``|grad|`` is taken; those per-parameter means are then averaged
    with equal weight (parameters are not weighted by their element count).

    Args:
        module: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The average as a Python float, or ``None`` when no parameter has a
        gradient.
    """
    grad_sum = 0.0
    counted = 0

    for p in module.parameters():
        grad = p.grad
        # Parameters without a gradient do not take part in the average.
        if grad is None:
            continue
        grad_sum += grad.abs().mean().item()
        counted += 1

    # No gradient anywhere: report None instead of dividing by zero.
    if counted == 0:
        return None
    return grad_sum / counted

class Text_Encoder(nn.Module):
    """Wrapper that turns (label text, modality) pairs into query embeddings.

    The wrapped encoder is selected by name, moved to ``device``, converted to
    SyncBatchNorm, wrapped in DistributedDataParallel, optionally initialised
    from a checkpoint and finally (partially) frozen.
    """

    # Integer code handed to the wrapped encoder for each supported modality.
    _MODALITY_CODES = {
        'ct': 0,
        'mri': 1,
        'us': 2,
        'pet': 3,
        'microscopy': 4,
    }

    def __init__(self, 
                 text_encoder,
                 checkpoint=None,
                 # other params
                 open_bert_layer=12,
                 open_modality_embed=False,
                 partial_load=False,
                 gpu_id=None,
                 device=None):
        """Build, optionally load, and freeze the wrapped encoder.

        Args:
            text_encoder (str): one of ``'ours'``, ``'medcpt'``, ``'basebert'``;
                any other value raises ``KeyError``.
            checkpoint (str | None): path given to ``torch.load``. The loaded
                object must contain a ``'model_state_dict'`` entry. Keys that
                contain ``'atlas_tower'`` or ``'temperature'`` are dropped.
            open_bert_layer (int): parameters named ``encoder.layer.<i>.*``
                are trainable only when ``i > open_bert_layer``. ``pooler`` and
                ``mlp_embed`` parameters are trainable only when
                ``open_bert_layer < 11``.
            open_modality_embed (bool): if True, parameters whose name contains
                ``'modality_embed'`` are trainable.
            partial_load (bool): if True, only checkpoint entries whose key and
                shape match the model are loaded, and the differences are
                printed on the master process; otherwise the state dict is
                loaded strictly.
            gpu_id: passed as the single entry of DDP's ``device_ids``.
            device: target device for the model and for ``map_location``; it
                is also forwarded to the wrapped model in ``forward``.
        """
        super().__init__()

        self.device = device
        
        # choose text encoder
        class_name = {
            'ours': Knowledge_Encoder,
            'medcpt': MedCPT,
            'basebert': BaseBERT,
        }[text_encoder]
        
        model = class_name()
        model = model.to(device)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)        
        # NOTE: the model is wrapped BEFORE the checkpoint is loaded and before
        # parameters are frozen, so state-dict keys and parameter names below
        # are those of the DDP wrapper. Keep this order.
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True)
        
        # load checkpoint
        if checkpoint:
            if is_master():
                print(f"** QUERY ** Load encoder from {checkpoint}.")
                
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint, map_location=device)
            # Drop entries that do not belong to the text encoder.
            checkpoint['model_state_dict'] = {k:v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower' not in k and 'temperature' not in k}
            if partial_load:
                model_dict =  model.state_dict()
                # check difference
                unexpected_state_dict = [k for k in checkpoint['model_state_dict'].keys() if k not in model_dict.keys()]
                missing_state_dict = [k for k in model_dict.keys() if k not in checkpoint['model_state_dict'].keys()]
                unmatchd_state_dict = [k for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape != model_dict[k].shape]
                # load partial parameters: only entries matching in name and shape
                state_dict = {k:v for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape == model_dict[k].shape}
                model_dict.update(state_dict)
                model.load_state_dict(model_dict)
                if is_master():
                    print('The following parameters are unexpected in query generator checkpoint:\n', unexpected_state_dict)
                    print('The following parameters are missing in query generator checkpoint:\n', missing_state_dict)
                    print('The following parameters have different shapes in query generator checkpoint:\n', unmatchd_state_dict)
                    print('The following parameters are loaded in query generator :\n', state_dict.keys())
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
                
        # open bert: every parameter is frozen unless one of the rules matches
        for name, param in model.named_parameters():
            if 'encoder.layer.' in name and int(name.split('encoder.layer.')[-1].split('.')[0])>open_bert_layer:  # encoder.layer.11.xxx --> 11
                param.requires_grad = True
            elif open_bert_layer < 11 and ('pooler' in name or 'mlp_embed' in name):
                param.requires_grad = True
            elif open_modality_embed and 'modality_embed' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
                
        self.model = model

    def forward(self, label_name, modality_name):
        """Encode label names, each paired with a modality, into queries.

        Args:
            label_name (List of List of Str / List of Str): B x N / N.
                Must be non-empty; in the nested form every row must hold the
                same number of labels.
            modality_name (List / Str): B / 1. In the nested form one modality
                per row of ``label_name``; otherwise a single modality applied
                to all labels. Supported names: ct, mri, us, pet, microscopy.
            NOTE: a list of labels paired with one modality

        Return:
            queries (Tensor): B x N x D for nested input. For flat input the
                output of the wrapped model is returned unchanged (presumably
                N x D -- confirm against the encoder).

        Raises:
            ValueError: nested rows differ in length, or ``modality_name`` does
                not provide exactly one entry per row.
            KeyError: a modality name is not supported.
        """
        batched = isinstance(label_name[0], list)

        if batched:
            batch_size = len(label_name)
            num_query = len(label_name[0])
            # Each modality is repeated num_query times, so a ragged batch or
            # a modality list of the wrong length would pair labels with the
            # wrong modality. Fail loudly instead.
            if any(len(row) != num_query for row in label_name):
                raise ValueError(
                    "every row of label_name must contain the same number of "
                    f"labels (expected {num_query})"
                )
            if len(modality_name) != batch_size:
                raise ValueError(
                    f"modality_name has {len(modality_name)} entries but "
                    f"label_name has {batch_size} rows"
                )
            input_text = [t for t_ls in label_name for t in t_ls]    # BN
            modality = [mod for mod in modality_name for _ in range(num_query)] # repeat each mod for N times -> BN
        else:
            num_query = len(label_name)
            input_text = label_name  # N
            modality = [modality_name for _ in range(num_query)]   # N
            
        # name to code
        try:
            codes = [self._MODALITY_CODES[mod] for mod in modality]
        except KeyError as err:
            raise KeyError(
                f"Unsupported modality {err.args[0]!r}; "
                f"expected one of {sorted(self._MODALITY_CODES)}"
            ) from None
        modality_code = torch.tensor(codes)   # bn
            
        # get embed
        queries = self.model(input_text, modality_code, self.device)
        
        if batched:
            queries = rearrange(queries, '(b n) d -> b n d', b=batch_size, n=num_query)

        return queries