Elfenreigen committed on
Commit 09cc9e3 · verified · 1 Parent(s): e695911

Upload 8 files

factory/.DS_Store ADDED
Binary file (6.15 kB).
 
factory/loss.py ADDED
@@ -0,0 +1,269 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ try:
+     import torch.distributed.nn
+     from torch import distributed as dist
+     has_distributed = True
+ except ImportError:
+     has_distributed = False
+
+ try:
+     import horovod.torch as hvd
+ except ImportError:
+     hvd = None
+
+
+ def gather_features(
+         image_features,
+         text_features,
+         local_loss=False,
+         gather_with_grad=False,
+         rank=0,
+         world_size=1,
+         use_horovod=False
+ ):
+     assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
+     if use_horovod:
+         assert hvd is not None, 'Please install horovod'
+         if gather_with_grad:
+             all_image_features = hvd.allgather(image_features)
+             all_text_features = hvd.allgather(text_features)
+         else:
+             with torch.no_grad():
+                 all_image_features = hvd.allgather(image_features)
+                 all_text_features = hvd.allgather(text_features)
+             if not local_loss:
+                 # ensure grads for local rank when all_* features don't have a gradient
+                 gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                 gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                 gathered_image_features[rank] = image_features
+                 gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+     else:
+         # We gather tensors from all gpus
+         if gather_with_grad:
+             all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+             all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+         else:
+             gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+             gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+             dist.all_gather(gathered_image_features, image_features)
+             dist.all_gather(gathered_text_features, text_features)
+             if not local_loss:
+                 # ensure grads for local rank when all_* features don't have a gradient
+                 gathered_image_features[rank] = image_features
+                 gathered_text_features[rank] = text_features
+             all_image_features = torch.cat(gathered_image_features, dim=0)
+             all_text_features = torch.cat(gathered_text_features, dim=0)
+
+     return all_image_features, all_text_features
+
+
+ class ClipLoss(nn.Module):
+     def __init__(
+             self,
+             local_loss=False,
+             gather_with_grad=False,
+             cache_labels=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False,
+     ):
+         super().__init__()
+         self.local_loss = local_loss
+         self.gather_with_grad = gather_with_grad
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+         self.use_horovod = use_horovod
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def forward(self, image_features, text_features):
+         # fixed logit scale (log(1 / 0.07)), re-created on every forward call
+         logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+         device = image_features.device
+         if self.world_size > 1:
+             all_image_features, all_text_features = gather_features(
+                 image_features, text_features,
+                 self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+             if self.local_loss:
+                 logits_per_image = logit_scale * image_features @ all_text_features.T
+                 logits_per_text = logit_scale * text_features @ all_image_features.T
+             else:
+                 logits_per_image = logit_scale * all_image_features @ all_text_features.T
+                 logits_per_text = logits_per_image.T
+         else:
+             logits_per_image = logit_scale * image_features @ text_features.T
+             logits_per_text = logit_scale * text_features @ image_features.T
+
+         # calculate ground-truth and cache if enabled
+         num_logits = logits_per_image.shape[0]
+         labels = torch.eye(num_logits, device=device, dtype=torch.float)
+         pred_1 = F.log_softmax(logits_per_image, dim=-1)
+         pred_2 = F.log_softmax(logits_per_text, dim=-1)
+         loss_a = F.kl_div(pred_1, labels, reduction='sum') / num_logits
+         loss_b = F.kl_div(pred_2, labels, reduction='sum') / num_logits
+         total_loss = (loss_a + loss_b) / 2
+         return total_loss
+
+ class AsymmetricLoss(nn.Module):
+     def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+         super(AsymmetricLoss, self).__init__()
+
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+         self.eps = eps
+
+     def forward(self, x, y, use_weight=False):
+         """
+         Parameters
+         ----------
+         x: input logits
+         y: targets (multi-label binarized vector)
+         """
+
+         # Calculating Probabilities
+         x_sigmoid = torch.sigmoid(x)
+         xs_pos = x_sigmoid
+         xs_neg = 1 - x_sigmoid
+
+         # Asymmetric Clipping
+         if self.clip is not None and self.clip > 0:
+             xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+         # Basic CE calculation
+         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+         loss = los_pos + los_neg
+
+         # Asymmetric Focusing
+         if self.gamma_neg > 0 or self.gamma_pos > 0:
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(False)
+             pt0 = xs_pos * y
+             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+             pt = pt0 + pt1
+             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(True)
+             loss *= one_sided_w
+         if use_weight:
+             return loss
+         return -loss.sum()
+
+ class RalSingleLoss(nn.Module):
+     '''
+     This loss is intended for single-label classification problems
+     '''
+     def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, epsilon_pos_pow=-2.5, reduction='mean'):
+         super(RalSingleLoss, self).__init__()
+
+         self.eps = eps
+         self.logsoftmax = nn.LogSoftmax(dim=-1)
+         self.targets_classes = []
+         self.gamma_pos = gamma_pos
+         self.gamma_neg = gamma_neg
+         self.reduction = reduction
+         self.epsilon_pos = 1.0
+         self.epsilon_neg = 0.0
+         self.epsilon_pos_pow = epsilon_pos_pow
+         self.lamb = 1.5
+
+     def forward(self, inputs, target):
+         '''
+         "input" dimensions: (batch_size, number_classes)
+         "target" dimensions: (batch_size)
+         '''
+         num_classes = inputs.size()[-1]
+         log_preds = self.logsoftmax(inputs)
+         self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
+
+         # ASL weights
+         targets = self.targets_classes
+         anti_targets = 1 - targets
+         xs_pos = torch.exp(log_preds)
+         xs_neg = 1 - xs_pos
+         xs_pos = torch.exp(log_preds) * (torch.log(xs_pos.clamp(min=self.eps)) + self.epsilon_pos * (1 - xs_pos.clamp(min=self.eps)) + self.epsilon_pos_pow * 0.5 * torch.pow(1 - xs_pos.clamp(min=self.eps), 2)) * torch.log(xs_pos)
+         xs_neg = (1 - xs_pos) * (torch.log(xs_neg.clamp(min=self.eps)) + self.epsilon_neg * (xs_neg.clamp(min=self.eps))) * -(self.lamb - xs_neg) * xs_neg ** 2
+         asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
+                                  self.gamma_pos * targets + self.gamma_neg * anti_targets)
+         log_preds = log_preds * asymmetric_w
+
+         if self.eps > 0:  # label smoothing
+             self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
+
+         # loss calculation
+         loss = -self.targets_classes.mul(log_preds)
+
+         loss = loss.sum(dim=-1)
+         if self.reduction == 'mean':
+             loss = loss.mean()
+
+         return loss
+
+ class Ralloss(nn.Module):
+     def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8, lamb=1.5, epsilon_neg=0.0, epsilon_pos=1.0, epsilon_pos_pow=-2.5, disable_torch_grad_focal_loss=False):
+         super(Ralloss, self).__init__()
+
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+         self.eps = eps
+
+         # parameters of Taylor expansion polynomials
+         self.epsilon_pos = epsilon_pos
+         self.epsilon_neg = epsilon_neg
+         self.epsilon_pos_pow = epsilon_pos_pow
+         self.margin = 1.0
+         self.lamb = lamb
+
+     def forward(self, x, y, use_weight=False):
+         """
+         x: input logits with size (batch_size, number of labels).
+         y: binarized multi-label targets with size (batch_size, number of labels).
+         """
+         # Calculating Probabilities
+         x_sigmoid = torch.sigmoid(x)
+         xs_pos = x_sigmoid
+         xs_neg = 1 - x_sigmoid
+
+         # Asymmetric Clipping
+         if self.clip is not None and self.clip > 0:
+             xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+         # Basic Taylor expansion polynomials
+         los_pos = y * (torch.log(xs_pos.clamp(min=self.eps)) + self.epsilon_pos * (1 - xs_pos.clamp(min=self.eps)) + self.epsilon_pos_pow * 0.5 * torch.pow(1 - xs_pos.clamp(min=self.eps), 2))
+         los_neg = (1 - y) * (torch.log(xs_neg.clamp(min=self.eps)) + self.epsilon_neg * (xs_neg.clamp(min=self.eps))) * (self.lamb - x_sigmoid) * x_sigmoid ** 2 * (self.lamb - xs_neg)
+         loss = los_pos + los_neg
+
+         # Asymmetric Focusing
+         if self.gamma_neg > 0 or self.gamma_pos > 0:
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(False)
+             pt0 = xs_pos * y
+             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+             pt = pt0 + pt1
+             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(True)
+             loss *= one_sided_w
+         if use_weight:
+             return loss
+         return -loss.sum()
+
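A minimal usage sketch for the losses defined above (an illustration added here, not part of the uploaded file). It assumes single-process execution, random inputs, and that the repository root is on the Python path so factory.loss is importable:

    import torch
    from factory.loss import ClipLoss, AsymmetricLoss, Ralloss

    # paired image/text embeddings for a toy batch (batch_size=4, dim=768)
    image_features = torch.randn(4, 768)
    text_features = torch.randn(4, 768)
    clip_loss = ClipLoss(world_size=1)  # single-GPU path, no feature gathering
    print(clip_loss(image_features, text_features))

    # multi-label logits and binarized targets (batch_size=4, num_labels=10)
    logits = torch.randn(4, 10)
    targets = torch.randint(0, 2, (4, 10)).float()
    print(AsymmetricLoss()(logits, targets))  # negated, summed log-likelihood with asymmetric focusing
    print(Ralloss()(logits, targets))         # Taylor-expansion variant with the same interface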
factory/metric.py ADDED
@@ -0,0 +1,23 @@
+ import sklearn
+ from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score, matthews_corrcoef
+
+ # gt_np = [0,1,1,0]
+ # pred_np = [0.2,0.7,0.8,0.3]
+ # fps,tps,thresholds = sklearn.metrics._ranking._binary_clf_curve(gt_np, pred_np)
+ # precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
+ # print(fps,tps,thresholds) # tps: an increasing count of true positives
+ # print(precision[:-1], recall[:-1], thresholds)
+ # tns = tps/recall[:-1] - tps
+
+ # print(tns + fps + tps)
+ # fns =
+ # [0. 1. 2.] [2. 2. 2.] [0.8 0.3 0.2]
+ # precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
+ # numerator = 2 * recall * precision
+ # denom = recall + precision
+ # f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0))
+ # max_f1 = np.max(f1_scores)
+ # max_f1_thresh = thresholds[np.argmax(f1_scores)]
+ # print('The max_f1_thresh is', max_f1_thresh)
+ # print('The average f1_score is', max_f1)
+ # print('The average acc_score is', accuracy_score(gt_np, pred_np>max_f1_thresh))
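The file above only contains commented-out exploration; a runnable version of the max-F1 threshold selection it sketches would look like this (illustrative only, using the same toy values as the comments):

    import numpy as np
    from sklearn.metrics import precision_recall_curve, accuracy_score

    gt_np = np.array([0, 1, 1, 0])
    pred_np = np.array([0.2, 0.7, 0.8, 0.3])

    precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
    max_f1 = np.max(f1_scores)
    max_f1_thresh = thresholds[np.argmax(f1_scores)]
    print('The max_f1_thresh is', max_f1_thresh)
    print('The max f1_score is', max_f1)
    print('The acc_score at that threshold is', accuracy_score(gt_np, pred_np > max_f1_thresh))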
factory/utils.py ADDED
@@ -0,0 +1,403 @@
1
+ import numpy as np
2
+ import io
3
+ import os
4
+ import time
5
+ import random
6
+ from collections import defaultdict, deque
7
+ import datetime
8
+ import subprocess
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+ class SmoothedValue(object):
13
+ """Track a series of values and provide access to smoothed values over a
14
+ window or the global series average.
15
+ """
16
+
17
+ def __init__(self, window_size=20, fmt=None):
18
+ if fmt is None:
19
+ fmt = "{median:.4f} ({global_avg:.4f})"
20
+ self.deque = deque(maxlen=window_size)
21
+ self.total = 0.0
22
+ self.count = 0
23
+ self.fmt = fmt
24
+
25
+ def update(self, value, n=1):
26
+ self.deque.append(value)
27
+ self.count += n
28
+ self.total += value * n
29
+
30
+ def synchronize_between_processes(self):
31
+ """
32
+ Warning: does not synchronize the deque!
33
+ """
34
+ if not is_dist_avail_and_initialized():
35
+ return
36
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
37
+ dist.barrier()
38
+ dist.all_reduce(t)
39
+ t = t.tolist()
40
+ self.count = int(t[0])
41
+ self.total = t[1]
42
+
43
+ @property
44
+ def median(self):
45
+ d = torch.tensor(list(self.deque))
46
+ return d.median().item()
47
+
48
+ @property
49
+ def avg(self):
50
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
51
+ return d.mean().item()
52
+
53
+ @property
54
+ def global_avg(self):
55
+ if self.count == 0:
56
+ return self.total
57
+ else:
58
+ return self.total / self.count
59
+
60
+ @property
61
+ def max(self):
62
+ return max(self.deque)
63
+
64
+ @property
65
+ def value(self):
66
+ return self.deque[-1]
67
+
68
+ def __str__(self):
69
+ return self.fmt.format(
70
+ median=self.median,
71
+ avg=self.avg,
72
+ global_avg=self.global_avg,
73
+ max=self.max,
74
+ value=self.value
75
+ )
76
+
77
+
78
+ class MetricLogger(object):
79
+ def __init__(self, delimiter="\t"):
80
+ self.meters = defaultdict(SmoothedValue)
81
+ self.delimiter = delimiter
82
+
83
+ def update(self, **kwargs):
84
+ for k, v in kwargs.items():
85
+ if isinstance(v, torch.Tensor):
86
+ v = v.item()
87
+ assert isinstance(v, (float, int))
88
+ self.meters[k].update(v)
89
+
90
+ def __getattr__(self, attr):
91
+ if attr in self.meters:
92
+ return self.meters[attr]
93
+ if attr in self.__dict__:
94
+ return self.__dict__[attr]
95
+ raise AttributeError("'{}' object has no attribute '{}'".format(
96
+ type(self).__name__, attr))
97
+
98
+ def __str__(self):
99
+ loss_str = []
100
+ for name, meter in self.meters.items():
101
+ loss_str.append(
102
+ "{}: {}".format(name, str(meter))
103
+ )
104
+ return self.delimiter.join(loss_str)
105
+
106
+ def global_avg(self):
107
+ loss_str = []
108
+ for name, meter in self.meters.items():
109
+ loss_str.append(
110
+ "{}: {:.4f}".format(name, meter.global_avg)
111
+ )
112
+ return self.delimiter.join(loss_str)
113
+
114
+ def synchronize_between_processes(self):
115
+ for meter in self.meters.values():
116
+ meter.synchronize_between_processes()
117
+
118
+ def add_meter(self, name, meter):
119
+ self.meters[name] = meter
120
+
121
+ def log_every(self, iterable, print_freq, header=None):
122
+ i = 0
123
+ if not header:
124
+ header = ''
125
+ start_time = time.time()
126
+ end = time.time()
127
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
128
+ data_time = SmoothedValue(fmt='{avg:.4f}')
129
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
130
+ log_msg = [
131
+ header,
132
+ '[{0' + space_fmt + '}/{1}]',
133
+ 'eta: {eta}',
134
+ '{meters}',
135
+ 'time: {time}',
136
+ 'data: {data}'
137
+ ]
138
+ if torch.cuda.is_available():
139
+ log_msg.append('max mem: {memory:.0f}')
140
+ log_msg = self.delimiter.join(log_msg)
141
+ MB = 1024.0 * 1024.0
142
+ for obj in iterable:
143
+ data_time.update(time.time() - end)
144
+ yield obj
145
+ iter_time.update(time.time() - end)
146
+ if i % print_freq == 0 or i == len(iterable) - 1:
147
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
+ if torch.cuda.is_available():
150
+ print(log_msg.format(
151
+ i, len(iterable), eta=eta_string,
152
+ meters=str(self),
153
+ time=str(iter_time), data=str(data_time),
154
+ memory=torch.cuda.max_memory_allocated() / MB))
155
+ else:
156
+ print(log_msg.format(
157
+ i, len(iterable), eta=eta_string,
158
+ meters=str(self),
159
+ time=str(iter_time), data=str(data_time)))
160
+ i += 1
161
+ end = time.time()
162
+ total_time = time.time() - start_time
163
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
164
+ print('{} Total time: {} ({:.4f} s / it)'.format(
165
+ header, total_time_str, total_time / len(iterable)))
166
+
167
+
168
+
169
+ class AttrDict(dict):
170
+ def __init__(self, *args, **kwargs):
171
+ super(AttrDict, self).__init__(*args, **kwargs)
172
+ self.__dict__ = self
173
+
174
+
175
+ def compute_acc(logits, label, reduction='mean'):
176
+ ret = (torch.argmax(logits, dim=1) == label).float()
177
+ if reduction == 'none':
178
+ return ret.detach()
179
+ elif reduction == 'mean':
180
+ return ret.mean().item()
181
+
182
+ def compute_n_params(model, return_str=True):
183
+ tot = 0
184
+ for p in model.parameters():
185
+ w = 1
186
+ for x in p.shape:
187
+ w *= x
188
+ tot += w
189
+ if return_str:
190
+ if tot >= 1e6:
191
+ return '{:.1f}M'.format(tot / 1e6)
192
+ else:
193
+ return '{:.1f}K'.format(tot / 1e3)
194
+ else:
195
+ return tot
196
+
197
+ def setup_for_distributed(is_master):
198
+ """
199
+ This function disables printing when not in master process
200
+ """
201
+ import builtins as __builtin__
202
+ builtin_print = __builtin__.print
203
+
204
+ def print(*args, **kwargs):
205
+ force = kwargs.pop('force', False)
206
+ if is_master or force:
207
+ builtin_print(*args, **kwargs)
208
+
209
+ __builtin__.print = print
210
+
211
+
212
+ def seed_worker(worker_id):
213
+ worker_seed = torch.initial_seed() % 2**32
214
+ np.random.seed(worker_seed)
215
+ random.seed(worker_seed)
216
+
217
+ def is_dist_avail_and_initialized():
218
+ if not dist.is_available():
219
+ return False
220
+ if not dist.is_initialized():
221
+ return False
222
+ return True
223
+
224
+
225
+ def get_world_size():
226
+ if not is_dist_avail_and_initialized():
227
+ return 1
228
+ return dist.get_world_size()
229
+
230
+
231
+ def get_rank():
232
+ if not is_dist_avail_and_initialized():
233
+ return 0
234
+ return dist.get_rank()
235
+
236
+
237
+ def is_main_process():
238
+ return get_rank() == 0
239
+
240
+
241
+ def save_on_master(*args, **kwargs):
242
+ if is_main_process():
243
+ torch.save(*args, **kwargs)
244
+
245
+ def init_distributed_mode(args):
246
+ if args.dist_on_itp:
247
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
248
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
249
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
250
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
251
+ os.environ['LOCAL_RANK'] = str(args.gpu)
252
+ os.environ['RANK'] = str(args.rank)
253
+ os.environ['WORLD_SIZE'] = str(args.world_size)
254
+ print('on tip')
255
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
256
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
257
+ args.rank = int(os.environ["RANK"])
258
+ args.world_size = int(os.environ['WORLD_SIZE'])
259
+ args.gpu = int(os.environ['LOCAL_RANK'])
260
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
261
+ print('rank')
262
+ elif 'SLURM_PROCID' in os.environ:
263
+ args.rank = int(os.environ['SLURM_PROCID'])
264
+ args.gpu = args.rank % torch.cuda.device_count()
265
+ else:
266
+ print('Not using distributed mode')
267
+ args.distributed = False
268
+ return
269
+
270
+ # args.distributed = False
271
+
272
+ # torch.cuda.set_device(args.gpu)
273
+ # args.dist_backend = 'gloo'
274
+ # print('| distributed init (rank {}): {}, gpu {}'.format(
275
+ # args.rank, args.dist_url, args.gpu), flush=True)
276
+ # print("flag1")
277
+ # print(args)
278
+ # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
279
+ # world_size=args.world_size, rank=args.rank)
280
+
281
+ # print("flag2")
282
+ # torch.distributed.barrier()
283
+ # setup_for_distributed(args.rank == 0)
284
+
285
+ args.distributed = False
286
+ args.dist_url ='tcp://localhost:12345'
287
+ args.world_size=1
288
+ args.rank = 0
289
+
290
+
291
+ # def init_distributed_mode(args,port='29511'):
292
+ # num_gpus = torch.cuda.device_count()
293
+ # if args.dist_on_itp:
294
+ # args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
295
+ # args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
296
+ # args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
297
+ # args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
298
+ # os.environ['LOCAL_RANK'] = str(args.gpu)
299
+ # os.environ['RANK'] = str(args.rank)
300
+ # os.environ['WORLD_SIZE'] = str(args.world_size)
301
+ # # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
302
+ # elif "SLURM_JOB_ID" in os.environ:
303
+ # print('SLURM_JOB_ID')
304
+ # args.rank = int(os.environ["SLURM_PROCID"])
305
+ # args.world_size = int(os.environ["SLURM_NTASKS"])
306
+ # node_list = os.environ["SLURM_NODELIST"]
307
+ # addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
308
+ # # specify master port
309
+ # if port is not None:
310
+ # os.environ["MASTER_PORT"] = str(port)
311
+ # elif "MASTER_PORT" not in os.environ:
312
+ # os.environ["MASTER_PORT"] = "29400"
313
+ # if "MASTER_ADDR" not in os.environ:
314
+ # os.environ["MASTER_ADDR"] = addr
315
+ # os.environ["WORLD_SIZE"] = str(args.world_size)
316
+ # os.environ["LOCAL_RANK"] = str(args.rank % num_gpus)
317
+ # os.environ["RANK"] = os.environ["WORLD_SIZE"]
318
+ # args.gpu = args.rank % torch.cuda.device_count()
319
+ # elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
320
+ # print('RANK')
321
+ # args.rank = int(os.environ["RANK"])
322
+ # args.world_size = int(os.environ['WORLD_SIZE'])
323
+ # args.gpu = int(os.environ['LOCAL_RANK'])
324
+ # else:
325
+ # print('Not using distributed mode')
326
+ # args.distributed = False
327
+ # return
328
+
329
+ # args.distributed = True
330
+
331
+ # torch.cuda.set_device(args.gpu)
332
+ # args.dist_backend = 'nccl'
333
+
334
+ # print('| distributed init (rank {}): {}, gpu {}'.format(
335
+ # args.rank, args.dist_url, args.gpu), flush=True)
336
+ # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
337
+ # world_size=args.world_size, rank=args.rank)
338
+ # print('Init_process_group')
339
+ # torch.distributed.barrier()
340
+ # print('distributed.barrier')
341
+ # setup_for_distributed(args.rank == 0)
342
+ # print('Finished distributed')
343
+
344
+
345
+ # def init_distributed_mode(args):
346
+ # # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
347
+ # # args.local_rank = os.environ['LOCAL_RANK']
348
+ # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
349
+ # args.rank = int(os.environ["RANK"])
350
+ # args.world_size = int(os.environ['WORLD_SIZE'])
351
+ # args.local_rank = int(os.environ['LOCAL_RANK'])
352
+ # elif 'SLURM_PROCID' in os.environ:
353
+ # args.rank = int(os.environ['SLURM_PROCID'])
354
+ # args.local_rank = args.rank % torch.cuda.device_count()
355
+ # else:
356
+ # print('Not using distributed mode')
357
+ # args.distributed = False
358
+ # return
359
+
360
+ # args.distributed = True
361
+
362
+ # torch.cuda.set_device(args.local_rank)
363
+ # args.dist_backend = 'nccl'
364
+ # print('| distributed init (rank {}): {}'.format(
365
+ # args.rank, args.dist_url), flush=True)
366
+ # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
367
+ # world_size=args.world_size, rank=args.rank)
368
+ # torch.distributed.barrier()
369
+ # setup_for_distributed(args.rank == 0)
370
+
371
+ # def init_distributed_mode(args):
372
+
373
+ # # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
374
+ # # args.rank = int(os.environ["RANK"])
375
+ # # args.world_size = int(os.environ['WORLD_SIZE'])
376
+ # # args.gpu = int(os.environ['LOCAL_RANK'])
377
+ # # elif 'SLURM_PROCID' in os.environ:
378
+ # # args.rank = int(os.environ['SLURM_PROCID'])
379
+ # # args.gpu = args.rank % torch.cuda.device_count()
380
+ # # else:
381
+ # # print('Not using distributed mode')
382
+ # # args.distributed = False
383
+ # # return
384
+ # # rank = int(os.environ['RANK']) # system env process ranks\
385
+ # # print(torch.distributed.get_world_size())
386
+
387
+ # args.distributed = True
388
+ # # torch.cuda.set_device(args.gpu)
389
+ # num_gpus = torch.cuda.device_count() # Returns the number of GPUs available
390
+ # torch.cuda.set_device(args.rank % num_gpus)
391
+ # # args.gpu = args.rank % torch.cuda.device_count()
392
+
393
+ # args.dist_backend = 'nccl'
394
+ # print('| distributed init (rank {}): {}'.format(
395
+ # args.rank, args.dist_url), flush=True)
396
+ # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
397
+ # world_size=args.world_size, rank=args.rank)
398
+ # torch.distributed.barrier()
399
+ # print('using distributed mode',args.rank, args.dist_url)
400
+ # setup_for_distributed(args.rank == 0)
401
+
402
+ # # export MASTER_ADDR=localhost
403
+ # export MASTER_PORT=5678
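A minimal usage sketch for the logging utilities defined above (illustrative only; single process, no distributed initialization, and a list of random tensors standing in for a real data loader):

    import torch
    from factory.utils import MetricLogger, SmoothedValue

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    data_loader = [torch.randn(2, 3, 224, 224) for _ in range(10)]
    for images in metric_logger.log_every(data_loader, print_freq=5, header='Train:'):
        loss = images.mean()  # placeholder for a real training loss
        metric_logger.update(loss=loss.item(), lr=1e-4)
    print('Averaged stats:', metric_logger.global_avg())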
models/.DS_Store ADDED
Binary file (6.15 kB).
 
models/clip_tqn.py ADDED
@@ -0,0 +1,546 @@
1
+
2
+ import sys
3
+ # add the parent directory to sys.path
4
+ sys.path.append(sys.path[0].replace('models', ''))
5
+
6
+ import re
7
+ import logging
8
+ import math
9
+ import json
10
+ import pathlib
11
+ import numpy as np
12
+ from copy import deepcopy
13
+ from pathlib import Path
14
+ from einops import rearrange
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ from typing import Tuple, Union, Callable, Optional
18
+
19
+
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+ import torchvision.models as models
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ from transformers import AutoModel,BertConfig,AutoTokenizer
27
+ # from pytorch_pretrained_vit import ViT
28
+
29
+ # from visualizer import get_local
30
+ from models.transformer_decoder import *
31
+
32
+ # from io import BytesIO
33
+ # from petrel_client.client import Client
34
+
35
+ # conf_path = '~/petreloss.conf'
36
+ # client = Client(conf_path)
37
+ from torch.autograd import Function
38
+ import timm
39
+
40
+ class ReverseLayerF(Function):
41
+
42
+ @staticmethod
43
+ def forward(ctx, x, alpha):
44
+ ctx.alpha = alpha
45
+ return x.view_as(x)
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ return grad_output.neg() * ctx.alpha, None
50
+
51
+
52
+ class DomainClassifier(nn.Module):
53
+ '''A single-layer classifier with a gradient reversal layer'''
54
+ def __init__(self, domain_nums=4, feature_dims=768):
55
+ super().__init__()
56
+ self.domain_nums = domain_nums
57
+ self.feature_dims = feature_dims
58
+ self.fc = nn.Linear(feature_dims, domain_nums)
59
+
60
+ def forward(self, x):
61
+ reverse_x = ReverseLayerF.apply(x, 1.0)
62
+ return self.fc(reverse_x)
63
+
64
+ class CLP_clinical(nn.Module):
65
+ def __init__(self,
66
+ bert_model_name: str,
67
+ embed_dim: int = 768,
68
+ freeze_layers:Union[Tuple[int, int], int] = None):
69
+ super().__init__()
70
+ self.bert_model = self._get_bert_basemodel(bert_model_name=bert_model_name, freeze_layers=freeze_layers)
71
+ self.mlp_embed = nn.Sequential(
72
+ nn.Linear(embed_dim, embed_dim),
73
+ nn.GELU(),
74
+ nn.Linear(embed_dim, embed_dim)
75
+ )
76
+ self.embed_dim = embed_dim
77
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
78
+ self.init_parameters()
79
+
80
+ def init_parameters(self):
81
+ nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
82
+ for m in self.mlp_embed:
83
+ if isinstance(m, nn.Linear):
84
+ nn.init.normal_(m.weight, std=self.embed_dim ** -0.5)
85
+
86
+ def _get_bert_basemodel(self, bert_model_name, freeze_layers=None):#12
87
+ try:
88
+ print(bert_model_name)
89
+ config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)#bert-base-uncased
90
+ model = AutoModel.from_pretrained(bert_model_name, config=config)#, return_dict=True)
91
+ print("Text feature extractor:", bert_model_name)
92
+ print("bert encoder layers:",len(model.encoder.layer))
93
+ except:
94
+ raise ValueError("Invalid model name. Check the config file and pass a BERT model from the transformers library")
95
+
96
+ if freeze_layers is not None:
97
+ for layer_idx in freeze_layers:
98
+ for param in list(model.encoder.layer[layer_idx].parameters()):
99
+ param.requires_grad = False
100
+ return model
101
+
102
+ def encode_text(self, text):
103
+ #input batch_size,token, return batch_size,dim
104
+ output = self.bert_model(input_ids = text['input_ids'],attention_mask = text['attention_mask'] )
105
+ last_hidden_state, pooler_output, hidden_states = output[0],output[1],output[2]
106
+ encode_out = self.mlp_embed(pooler_output)
107
+ # encode_out = pooler_output
108
+ return encode_out
109
+
110
+ def forward(self, text):
111
+ #input batch_size,token, return batch_size,dim
112
+ output = self.bert_model(input_ids = text['input_ids'],attention_mask = text['attention_mask'] )
113
+ last_hidden_state, pooler_output, hidden_states = output[0],output[1],output[2]
114
+ encode_out = self.mlp_embed(pooler_output)
115
+ # encode_out = pooler_output
116
+ return encode_out
117
+
118
+ # def forward(self,text1,text2):
119
+ # text1_features = self.encode_text(text1)
120
+ # text2_features = self.encode_text(text2)
121
+ # text1_features = F.normalize(text1_features, dim=-1)
122
+ # text2_features = F.normalize(text2_features, dim=-1)
123
+ # return text1_features, text2_features, self.logit_scale.exp()
124
+
125
+ class ModelRes(nn.Module):
126
+ def __init__(self, res_base_model):
127
+ super(ModelRes, self).__init__()
128
+ self.resnet_dict = {
129
+ "resnet50": models.resnet50(pretrained=True),
130
+ "resnet101": models.resnet101(pretrained=True),
131
+ "resnet152": models.resnet152(pretrained=True),
132
+ "resnet50_openai": None,
133
+ 'resnet101_openai': None,
134
+ 'resnet50x4_openai': None,
135
+ }
136
+ # "resnet50": models.resnet50(pretrained=True)}
137
+ self.resnet = self._get_res_basemodel(res_base_model)
138
+ # num_ftrs = int(self.resnet.fc.in_features/2)
139
+ # self.res_features = nn.Sequential(*list(self.resnet.children())[:-3]) 224
140
+ if 'openai' in res_base_model:
141
+ # redefine res_features for the OpenAI CLIP ResNet variants
142
+ num_ftrs = int(self.resnet.attnpool.v_proj.in_features)
143
+ self.res_features = nn.Sequential(*list(self.resnet.children())[:-1])
144
+ else:
145
+ num_ftrs = int(self.resnet.fc.in_features)
146
+ self.res_features = nn.Sequential(*list(self.resnet.children())[:-2])
147
+ # here num_ftrs = 2048
148
+ self.res_l1 = nn.Linear(num_ftrs, num_ftrs)
149
+ self.res_l2 = nn.Linear(num_ftrs, 768)
150
+
151
+ def _get_res_basemodel(self, res_model_name):
152
+ try:
153
+ res_model = self.resnet_dict[res_model_name]
154
+ print("Image feature extractor:", res_model_name)
155
+ return res_model
156
+ except:
157
+ raise ValueError("Invalid model name. Check the config file and pass one of: resnet50, resnet101, resnet152, resnet50_openai, resnet101_openai or resnet50x4_openai")
158
+
159
+ def forward(self, img):
160
+ #return (batchsize, patch_num, dim)
161
+ batch_size = img.shape[0]
162
+ res_fea = self.res_features(img)
163
+ # return res_fea
164
+ # res_fea = F.adaptive_avg_pool2d(res_fea, (1, 1))
165
+ res_fea = rearrange(res_fea,'b d n1 n2 -> b (n1 n2) d')
166
+ h = rearrange(res_fea,'b n d -> (b n) d')
167
+ x = self.res_l1(h)
168
+ x = F.relu(x)
169
+ x = self.res_l2(x)
170
+ out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
171
+ out_pool = torch.mean(out_emb,dim=1)
172
+ return out_emb,out_pool
173
+
174
+
175
+ class ModelConvNeXt(nn.Module):
176
+ def __init__(self, convnext_base_model):
177
+ super(ModelConvNeXt, self).__init__()
178
+ self.convnext_dict = {"convnext-tiny": timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000),
179
+ "convnext-base": timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000),
180
+ }
181
+ convnext = self._get_convnext_basemodel(convnext_base_model)
182
+ num_ftrs = int(convnext.head.in_features)
183
+ self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
184
+ self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
185
+ self.conv_l2 = nn.Linear(num_ftrs, 768)
186
+
187
+
188
+ def _get_convnext_basemodel(self, convnext_model_name):
189
+ try:
190
+ convnext_model = self.convnext_dict[convnext_model_name]
191
+ print("Image feature extractor:", convnext_model_name)
192
+ return convnext_model
193
+ except:
194
+ raise ValueError("Invalid model name. Check the config file and pass one of: convnext-tiny or convnext-base")
195
+
196
+ def forward(self, img):
197
+ #return (batchsize, patch_num, dim)
198
+ batch_size = img.shape[0]
199
+ conv_fea = self.conv_features(img)
200
+ conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
201
+ conv_fea = rearrange(conv_fea,'b d n1 n2 -> b (n1 n2) d')
202
+ h = rearrange(conv_fea,'b n d -> (b n) d')
203
+ x = self.conv_l1(h)
204
+ x = F.relu(x)
205
+ x = self.conv_l2(x)
206
+ out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
207
+ out_pool = torch.mean(out_emb,dim=1)
208
+ return out_emb,out_pool
209
+
210
+
211
+ # class ModelConvNeXt(nn.Module):
212
+ # def __init__(self, convnext_base_model):
213
+ # super(ModelConvNeXt, self).__init__()
214
+ # self.convnext_dict = {"convnext-tiny": models.convnext_tiny(weights='ConvNeXt_Tiny_Weights.DEFAULT'),
215
+ # "convnext-small": models.convnext_small(weights='ConvNeXt_Small_Weights.DEFAULT'),
216
+ # "convnext-base": models.convnext_base(weights='ConvNeXt_Base_Weights.DEFAULT'),
217
+ # }
218
+ # convnext = self._get_convnext_basemodel(convnext_base_model)
219
+ # num_ftrs = int(convnext.classifier[-1].in_features)
220
+ # self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
221
+ # self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
222
+ # self.conv_l2 = nn.Linear(num_ftrs, 768)
223
+
224
+
225
+ # def _get_convnext_basemodel(self, convnext_model_name):
226
+ # try:
227
+ # convnext_model = self.convnext_dict[convnext_model_name]
228
+ # print("Image feature extractor:", convnext_model_name)
229
+ # return convnext_model
230
+ # except:
231
+ # raise ("Invalid model name. Check the config file and pass one of: convnext-tiny, convnext-small or convnext-base")
232
+
233
+ # def forward(self, img):
234
+ # #return (batchsize, patch_num, dim)
235
+ # batch_size = img.shape[0]
236
+ # conv_fea = self.conv_features(img)
237
+ # conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
238
+ # conv_fea = rearrange(conv_fea,'b d n1 n2 -> b (n1 n2) d')
239
+ # h = rearrange(conv_fea,'b n d -> (b n) d')
240
+ # x = self.conv_l1(h)
241
+ # x = F.relu(x)
242
+ # x = self.conv_l2(x)
243
+ # out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
244
+ # out_pool = torch.mean(out_emb,dim=1)
245
+ # return out_emb,out_pool
246
+
247
+ # import open_clip
248
+ # class ModelCLIP(nn.Module):
249
+ # def __init__(self, clip_base_model):
250
+ # super(ModelCLIP, self).__init__()
251
+ # load a different model depending on clip_base_model
252
+ # if clip_base_model == 'openai_EVA02-B-16':
253
+ # model, _, preprocess = open_clip.create_model_and_transforms('EVA02-B-16', pretrained='merged2b_s8b_b131k')
254
+
255
+ # elif clip_base_model == 'openai_convnext_base_w':
256
+ # model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')
257
+
258
+ # else:
259
+ # raise ("Invalid model name. Check the config file and pass one of: EVA02-B-16 or convnext_base_w")
260
+
261
+
262
+
263
+ class ModelEfficientV2(nn.Module):
264
+ def __init__(self, efficientv2_base_model):
265
+ super(ModelEfficientV2, self).__init__()
266
+ self.efficientv2_dict = {"efficientnet_v2_s": models.efficientnet_v2_s(weights='EfficientNet_V2_S_Weights.IMAGENET1K_V1'),}
267
+ self.efficientv2_model = self._get_efficientv2_basemodel(efficientv2_base_model)
268
+ num_ftrs = int(self.efficientv2_model.classifier[-1].in_features)
269
+ self.efficientv2_features = nn.Sequential(*list(self.efficientv2_model.children())[:-2])
270
+ self.efficientv2_l1 = nn.Linear(num_ftrs, num_ftrs)
271
+ self.efficientv2_l2 = nn.Linear(num_ftrs, 768)
272
+
273
+
274
+ def _get_efficientv2_basemodel(self, efficientv2_model_name):
275
+ try:
276
+ efficientv2_model = self.efficientv2_dict[efficientv2_model_name]
277
+ print("Image feature extractor:", efficientv2_model_name)
278
+ return efficientv2_model
279
+ except:
280
+ raise ValueError("Invalid model name. Check the config file and pass efficientnet_v2_s")
281
+
282
+ def forward(self, img):
283
+ batch_size = img.shape[0]
284
+ efficientv2_fea = self.efficientv2_features(img)
285
+ # efficientv2_fea = F.adaptive_avg_pool2d(efficientv2_fea, (1, 1))
286
+ # print(efficientv2_fea.shape)
287
+ efficientv2_fea = rearrange(efficientv2_fea,'b d n1 n2 -> b (n1 n2) d')
288
+ # print(efficientv2_fea.shape)
289
+ h = rearrange(efficientv2_fea,'b n d -> (b n) d')
290
+ # print(h.shape)
291
+ x = self.efficientv2_l1(h)
292
+ x = F.relu(x)
293
+ x = self.efficientv2_l2(x)
294
+ # print(x.shape)
295
+ out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
296
+ out_pool = torch.mean(out_emb,dim=1)
297
+ return out_emb,out_pool
298
+
299
+
300
+
301
+ class ModelDense(nn.Module):
302
+ def __init__(self, dense_base_model):
303
+ super(ModelDense, self).__init__()
304
+
305
+ self.densenet_dict = {"densenet121": models.densenet121(weights='DenseNet121_Weights.IMAGENET1K_V1'),
306
+ "densenet161": models.densenet161(weights='DenseNet161_Weights.IMAGENET1K_V1'),
307
+ "densenet201": models.densenet201(weights='DenseNet201_Weights.IMAGENET1K_V1'),}
308
+ self.densenet = self._get_dense_basemodel(dense_base_model)
309
+ num_ftrs = int(self.densenet.classifier.in_features)
310
+ self.dense_features = self.densenet.features
311
+ self.dense_l1 = nn.Linear(num_ftrs, num_ftrs)
312
+ self.dense_l2 = nn.Linear(num_ftrs, 768)
313
+
314
+ def _get_dense_basemodel(self, dense_base_model):
315
+ try:
316
+ dense_model = self.densenet_dict[dense_base_model]
317
+ print("Image feature extractor:", dense_base_model)
318
+ return dense_model
319
+ except:
320
+ raise ValueError("Invalid model name. Check the config file and pass one of: densenet121, densenet161 or densenet201")
321
+
322
+ def forward(self, img):
323
+ batch_size = img.shape[0]
324
+ dense_fea = self.dense_features(img)#N, 1024, 7,7
325
+ dense_fea = rearrange(dense_fea,'b d n1 n2 -> b (n1 n2) d')
326
+ h = rearrange(dense_fea,'b n d -> (b n) d')
327
+ x = self.dense_l1(h)
328
+ x = F.relu(x)
329
+ x = self.dense_l2(x)
330
+ out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
331
+ out_pool = torch.mean(out_emb,dim=1)
332
+ return out_emb,out_pool
333
+
334
+ class TQN_Model(nn.Module):
335
+ def __init__(self,
336
+ embed_dim: int = 768,
337
+ class_num: int = 1,
338
+ lam: list = [1, 0]
339
+ ):
340
+ super().__init__()
341
+ self.d_model = embed_dim
342
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
343
+ # decoder_layer = TransformerDecoderLayer(self.d_model, 4, 1024,
344
+ # 0.1, 'relu',normalize_before=True)
345
+ decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
346
+ 0.1, 'relu', True, lam)
347
+ self.decoder_norm = nn.LayerNorm(self.d_model)
348
+ # self.decoder = TransformerDecoder(decoder_layer, 4, self.decoder_norm,
349
+ # return_intermediate=False)
350
+ self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
351
+ return_intermediate=False)
352
+
353
+ self.dropout_feas = nn.Dropout(0.1)
354
+
355
+ # class_num = 2
356
+ self.mlp_head = nn.Sequential( # nn.LayerNorm(768),
357
+ nn.Linear(embed_dim, class_num)
358
+ )
359
+ self.apply(self._init_weights)
360
+
361
+ @staticmethod
362
+ def _init_weights(module):
363
+ if isinstance(module, nn.Linear):
364
+ module.weight.data.normal_(mean=0.0, std=0.02)
365
+
366
+ elif isinstance(module, nn.MultiheadAttention):
367
+ module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
368
+ module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
369
+
370
+ elif isinstance(module, nn.Embedding):
371
+ module.weight.data.normal_(mean=0.0, std=0.02)
372
+ if module.padding_idx is not None:
373
+ module.weight.data[module.padding_idx].zero_()
374
+
375
+ # def forward(self, image_features, text_features):
376
+ # #image_features (batch_size,patch_num,dim)
377
+ # #text_features (query_num,dim)
378
+ # batch_size = image_features.shape[0]
379
+ # image_features = image_features.transpose(0,1)
380
+ # text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
381
+ # image_features = self.decoder_norm(image_features)
382
+ # text_features = self.decoder_norm(text_features)
383
+
384
+ # # features = self.decoder(text_features, image_features,
385
+ # # memory_key_padding_mask=None, pos=None, query_pos=None)
386
+
387
+ # image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0)
388
+ # features = self.decoderV1(text_features, image_features, image_features_pool,
389
+ # memory_key_padding_mask=None, pos=None, query_pos=None)
390
+
391
+
392
+
393
+ # features = self.dropout_feas(features).transpose(0,1) #b,embed_dim
394
+ # out = self.mlp_head(features) #(batch_size, query_num)
395
+ # # out = out.squeeze(-1)
396
+ # return out
397
+
398
+ def forward(self, image_features, text_features, return_atten = False):
399
+ #image_features (batch_size,patch_num,dim)
400
+ #text_features (query_num,dim)
401
+ batch_size = image_features.shape[0]
402
+ image_features = image_features.transpose(0,1)
403
+ text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
404
+ image_features = self.decoder_norm(image_features)
405
+ text_features = self.decoder_norm(text_features)
406
+
407
+ image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0)
408
+ features,atten_map = self.decoderV1(text_features, image_features, image_features_pool,
409
+ memory_key_padding_mask=None, pos=None, query_pos=None)
410
+ features = self.dropout_feas(features).transpose(0,1) #b,embed_dim
411
+ out = self.mlp_head(features) #(batch_size, query_num)
412
+ if return_atten:
413
+ return out, atten_map
414
+ else:
415
+ return out
416
+
417
+
418
+
419
+ class TQN_Model_Ensemble(nn.Module):
420
+ def __init__(self,
421
+ embed_dim: int = 768,
422
+ class_num: int = 1,
423
+ lam: list = [1, 0]
424
+ ):
425
+ super().__init__()
426
+ self.d_model = embed_dim
427
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
428
+ decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
429
+ 0.1, 'relu', True, lam)
430
+ self.decoder_norm = nn.LayerNorm(self.d_model)
431
+ self.decoder_norm_1 = nn.LayerNorm(self.d_model)
432
+ self.decoder_norm_2 = nn.LayerNorm(self.d_model)
433
+ self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
434
+ return_intermediate=False)
435
+ self.decoderV1_1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_1,
436
+ return_intermediate=False)
437
+ self.decoderV1_2 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_2,
438
+ return_intermediate=False)
439
+
440
+ self.dropout_feas = nn.Dropout(0.1)
441
+
442
+ # class_num = 2
443
+ self.mlp_head = nn.Sequential(nn.Linear(embed_dim, class_num))
444
+ self.mlp_head_1 = nn.Sequential(nn.Linear(embed_dim, class_num))
445
+ self.mlp_head_2 = nn.Sequential(nn.Linear(embed_dim, class_num))
446
+ self.apply(self._init_weights)
447
+
448
+ @staticmethod
449
+ def _init_weights(module):
450
+ if isinstance(module, nn.Linear):
451
+ module.weight.data.normal_(mean=0.0, std=0.02)
452
+
453
+ elif isinstance(module, nn.MultiheadAttention):
454
+ module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
455
+ module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
456
+
457
+ elif isinstance(module, nn.Embedding):
458
+ module.weight.data.normal_(mean=0.0, std=0.02)
459
+ if module.padding_idx is not None:
460
+ module.weight.data[module.padding_idx].zero_()
461
+
462
+ def forward(self, image_features, text_features, return_atten = False):
463
+
464
+ batch_size = image_features.shape[0]
465
+ image_features = image_features.transpose(0,1)
466
+ text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
467
+ image_features = self.decoder_norm(image_features)
468
+ image_features_1 = self.decoder_norm_1(image_features)
469
+ image_features_2 = self.decoder_norm_2(image_features)
470
+
471
+ text_features = self.decoder_norm(text_features)
472
+ text_features_1 = self.decoder_norm_1(text_features)
473
+ text_features_2 = self.decoder_norm_2(text_features)
474
+
475
+ image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0)
476
+ image_features_pool_1 = torch.mean(image_features_1,dim=0).unsqueeze(0)
477
+ image_features_pool_2 = torch.mean(image_features_2,dim=0).unsqueeze(0)
478
+
479
+
480
+ features,atten_map = self.decoderV1(text_features, image_features, image_features_pool,
481
+ memory_key_padding_mask=None, pos=None, query_pos=None)
482
+ features = self.dropout_feas(features).transpose(0,1)
483
+ out = self.mlp_head(features)
484
+
485
+ features_1,atten_map_1 = self.decoderV1_1(text_features_1, image_features_1, image_features_pool_1,
486
+ memory_key_padding_mask=None, pos=None, query_pos=None)
487
+ features_1 = self.dropout_feas(features_1).transpose(0,1)
488
+ out_1 = self.mlp_head_1(features_1)
489
+
490
+ features_2,atten_map_2 = self.decoderV1_2(text_features_2, image_features_2, image_features_pool_2,
491
+ memory_key_padding_mask=None, pos=None, query_pos=None)
492
+ features_2 = self.dropout_feas(features_2).transpose(0,1)
493
+ out_2 = self.mlp_head_2(features_2)
494
+
495
+
496
+ out_stack = torch.stack([out, out_1, out_2])
497
+ out = torch.mean(out_stack, dim=0)
498
+
499
+
500
+
501
+
502
+
503
+
504
+ if return_atten:
505
+ return out, atten_map
506
+ else:
507
+ return out
508
+
509
+
510
+
511
+ # For MIMIC: batch_size=32, query_num=41, patch_num=256, dim=768
512
+ # img 256, 32, 768
513
+ # txt 1, 32, 768
514
+ # query41, 32, 768
515
+ # fts 41, 32, 768
516
+ # out 41, 32, 1
517
+ # outputs are not passed through sigmoid here; apply sigmoid when computing the loss!
518
+
519
+
520
+ if __name__ == "__main__":
521
+
522
+ #torch 1.10.2 to torch 1.12.1
523
+ #torchvision-0.11.3 to torchvision-0.13.1
524
+
525
+ # image = torch.randn(1, 3, 224, 224)
526
+ # image_encoder = ModelRes(res_base_model = 'resnet50')
527
+ # # image_encoder = ModelDense(dense_base_model = 'densenet121')
528
+ # # image_encoder = ModelViT(vit_base_model = 'vit_b_32')
529
+ # image_encoder(image)
530
+
531
+ # image = torch.randn(256, 1, 768)
532
+ # query = torch.randn(41, 768)
533
+ # model = TQN_Model()
534
+ # out = model(image, query)
535
+
536
+
537
+ # img = torch.randn(1,3,512,512)
538
+ img = torch.randn(2,3,224,224)
539
+ # model = ModelConvNeXt(convnext_base_model = 'convnext-base')
540
+ # model = ModelEfficientV2(efficientv2_base_model = 'efficientnet_v2_s')
541
+ model = ModelRes(res_base_model = 'resnet50_openai')
542
+ out_emb, out_pool = model(img)
543
+
544
+ print(out_emb.size(), out_pool.size())
545
+
546
+
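A minimal forward-pass sketch for the encoder and decoder classes above, following the shapes noted in the file's own comments (illustrative only; it assumes models/transformer_decoder.py from this upload is importable and that the pretrained torchvision ResNet-50 weights can be downloaded):

    import torch
    from models.clip_tqn import ModelRes, TQN_Model

    image_encoder = ModelRes(res_base_model='resnet50')  # torchvision ResNet-50 backbone
    decoder = TQN_Model(embed_dim=768, class_num=1)

    img = torch.randn(2, 3, 224, 224)       # batch of 2 images
    query = torch.randn(41, 768)            # 41 query (text) embeddings
    out_emb, out_pool = image_encoder(img)  # (2, patch_num, 768) and (2, 768)
    out = decoder(out_emb, query)           # (2, 41, 1) raw logits; sigmoid is applied in the loss
    print(out_emb.shape, out.shape)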
models/resnet.py ADDED
@@ -0,0 +1,1382 @@
1
+ from functools import partial
2
+ from typing import Type, Any, Callable, Union, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ from torchvision.transforms._presets import ImageClassification
9
+ from torchvision.utils import _log_api_usage_once
10
+ from torchvision.models._api import WeightsEnum, Weights
11
+ from torchvision.models._meta import _IMAGENET_CATEGORIES
12
+ from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
13
+ import math
14
+ import torch.nn.functional as F
15
+
16
+ import random
17
+
18
+ from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
19
+ class LoRALayer(nn.Module):
20
+ """
21
+ Base lora class
22
+ """
23
+ def __init__(
24
+ self,
25
+ r,
26
+ lora_alpha,
27
+ ):
28
+ super().__init__()
29
+ self.r = r
30
+ self.lora_alpha = lora_alpha
31
+ # Mark the weight as unmerged
32
+ self.merged = False
33
+
34
+ def reset_parameters(self):
35
+ raise NotImplementedError
36
+
37
+ def train(self, mode:bool = True):
38
+ raise NotImplementedError
39
+
40
+ def eval(self):
41
+ raise NotImplementedError
42
+
43
+
44
+ class LoRALinear(LoRALayer):
45
+ def __init__(self, r, lora_alpha, linear_layer):
46
+ """
47
+ LoRA class for nn.Linear class
48
+ :param r: low rank dimension
49
+ :param lora_alpha: scaling factor
50
+ :param linear_layer: target nn.Linear layer for applying Lora
51
+ """
52
+ super().__init__(r, lora_alpha)
53
+ self.linear = linear_layer
54
+
55
+ in_features = self.linear.in_features
56
+ out_features = self.linear.out_features
57
+
58
+ # Lora configuration
59
+ self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
60
+ self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
61
+ self.scaling = self.lora_alpha / self.r
62
+ self.reset_parameters()
63
+
64
+ def reset_parameters(self):
65
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
66
+ nn.init.zeros_(self.lora_B)
67
+
68
+
69
+ def train(self, mode:bool = True):
70
+ self.linear.train(mode)
71
+ if self.merged:
72
+ self.linear.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
73
+ self.merged = False
74
+
75
+
76
+ def eval(self):
77
+ self.linear.eval()
78
+ if not self.merged:
79
+ self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
80
+ self.merged = True
81
+
82
+
83
+ def forward(self, x):
84
+ if not self.merged:
85
+ result = F.linear(x, self.linear.weight, bias=self.linear.bias)
86
+ out = (x @ self.lora_A.T @ self.lora_B.T)
87
+ result += out
88
+ return result
89
+ else:
90
+ return F.linear(x, self.linear.weight, bias=self.linear.bias)
91
+
92
+
93
+ # class LoraConv2d(LoRALayer):
94
+ # def __init__(self, r, lora_alpha, conv_layer):
95
+ # """
96
+ # LoRA class for nn.Conv2d class
97
+ # """
98
+ # super().__init__(r, lora_alpha)
99
+ # self.conv = conv_layer
100
+
101
+ # in_channels = self.conv.in_channels
102
+ # out_channels = self.conv.out_channels
103
+ # kernel_size = self.conv.kernel_size[0]
104
+
105
+ # # lora configuration
106
+ # self.lora_A = nn.Parameter(
107
+ # self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
108
+ # )
109
+ # self.lora_B = nn.Parameter(
110
+ # self.conv.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
111
+ # )
112
+ # self.scaling = self.lora_alpha / self.r
113
+ # self.reset_parameters()
114
+
115
+ # def reset_parameters(self):
116
+ # nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
117
+ # nn.init.zeros_(self.lora_B)
118
+
119
+ # def train(self, mode: bool = True):
120
+ # self.conv.train(mode)
121
+ # if self.merged:
122
+ # # Make sure that the weights are not merged
123
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
124
+ # self.merged = False
125
+
126
+ # def eval(self):
127
+ # self.conv.eval()
128
+ # if not self.merged:
129
+ # # Merge the weights and mark it
130
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
131
+ # self.merged = True
132
+
133
+ # def forward(self, x):
134
+ # if not self.merged:
135
+ # return F.conv2d(
136
+ # x,
137
+ # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
138
+ # self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation, self.conv.groups
139
+ # )
140
+ # return self.conv(x)
141
+
142
+
143
+
144
+ class LoraConv2d(nn.Conv2d):
145
+ def __init__(
146
+ self,
147
+ r: int,
148
+ lora_alpha: float,
149
+ in_channels: int,
150
+ out_channels: int,
151
+ kernel_size: _size_2_t,
152
+ stride: _size_2_t = 1,
153
+ padding: Union[str, _size_2_t] = 0,
154
+ dilation: _size_2_t = 1,
155
+ groups: int = 1,
156
+ bias: bool = True,
157
+ padding_mode: str = 'zeros', # TODO: refine this type
158
+ device=None,
159
+ dtype=None
160
+ ):
161
+ """
162
+ LoRA class for nn.Conv2d class
163
+ """
164
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
165
+ self.r = r
166
+ self.lora_alpha = lora_alpha
167
+ # lora configuration
168
+ self.lora_A = nn.Parameter(
169
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
170
+ )
171
+ self.lora_B = nn.Parameter(
172
+ self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
173
+ )
174
+ self.scaling = self.lora_alpha / self.r
175
+ self.reset_parameters_lora()
176
+ self.merged = False
177
+ self.drop_lora_rate = 0.9
178
+
179
+ def reset_parameters_lora(self):
180
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
181
+ nn.init.zeros_(self.lora_B)
182
+
183
+ def train(self, mode: bool = True):
184
+ super().train(mode)
185
+ if self.merged:
186
+ # Make sure that the weights are not merged
187
+ self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
188
+ self.merged = False
189
+
190
+ def eval(self):
191
+ super().eval()
192
+ if not self.merged:
193
+ # Merge the weights and mark it
194
+ self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
195
+ self.merged = True
196
+
197
+ def forward(self, x):
198
+ # draw a random number
199
+ # drop_rate = random.random()
200
+
201
+ # # during training, skip the LoRA branch with a certain probability
202
+ # if drop_rate <= self.drop_lora_rate and self.training:
203
+ # return F.conv2d(
204
+ # x,
205
+ # self.weight,
206
+ # self.bias, self.stride, self.padding, self.dilation, self.groups
207
+ # )
208
+ # else:
209
+ return F.conv2d(
210
+ x,
211
+ self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
212
+ self.bias, self.stride, self.padding, self.dilation, self.groups
213
+ )
214
+
215
+
216
+
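+ # A minimal usage sketch for LoraConv2d (hypothetical helper, illustration only; nothing else
+ # in this file calls it). It assumes the layer is used exactly as defined above.
+ def _lora_conv_usage_sketch():
+     """lora_B is zero-initialized, so before training the layer matches the plain convolution;
+     during fine-tuning only lora_A / lora_B need gradients."""
+     conv = LoraConv2d(r=4, lora_alpha=16, in_channels=3, out_channels=8, kernel_size=3, padding=1, bias=False)
+     x = torch.randn(2, 3, 32, 32)
+     y = conv(x)                                   # weight + (lora_B @ lora_A) * scaling, applied on the fly
+     y_base = F.conv2d(x, conv.weight, padding=1)  # plain convolution with the frozen base weight
+     return torch.allclose(y, y_base)              # True at initialization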
217
+ class MultiLoRALinear(LoRALayer):
218
+ def __init__(self, r, lora_alpha, linear_layer,lora_num):
219
+ """
220
+ LoRA class for nn.Linear class
221
+ :param r: low rank dimension
222
+ :param lora_alpha: scaling factor
223
+ :param linear_layer: target nn.Linear layer for applying Lora
224
+ """
225
+ super().__init__(r,lora_alpha)
226
+ self.linear = linear_layer
227
+ self.lora_num = lora_num
228
+ self.r_list = r
229
+
230
+ in_features = self.linear.in_features
231
+ out_features = self.linear.out_features
232
+
233
+ # Lora configuration
234
+ self.lora_A_list = nn.ParameterList([nn.Parameter(self.linear.weight.new_zeros((self.r_list[th], in_features))) for th in range(self.lora_num)])
235
+ self.lora_B_list = nn.ParameterList([nn.Parameter(self.linear.weight.new_zeros((out_features, self.r_list[th]))) for th in range(self.lora_num)])
236
+ # self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
237
+ # self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
238
+ self.scaling = [self.lora_alpha / self.r_list[th] for th in range(self.lora_num)]
239
+ self.reset_parameters()
240
+
241
+ def reset_parameters(self):
242
+ for th in range(self.lora_num):
243
+ nn.init.kaiming_uniform_(self.lora_A_list[th], a=math.sqrt(5))
244
+ nn.init.zeros_(self.lora_B_list[th])
245
+
246
+ def train(self, mode:bool = True):
247
+ self.linear.train(mode)
248
+
249
+ def eval(self):
250
+ self.linear.eval()
251
+
252
+ def forward(self, x, weights):
253
+ if not self.merged:
254
+ result = F.linear(x, self.linear.weight, bias=self.linear.bias) # (247, batch, 768)
255
+ out_stack = torch.stack([(x @ self.lora_A_list[th].T @ self.lora_B_list[th].T) * self.scaling[th] for th in range(self.lora_num)], dim=2) # (2353,16,3,768)
256
+ # (247, batch, lora_num, 768)
257
+ # weights = weights.unsqueeze(0).unsqueeze(-1)
258
+ # (1, batch, lora_num, 1)
259
+ # out = torch.sum(out_stack * weights,dim=2)
260
+ out = torch.sum(out_stack, dim=2)
261
+ # (247, batch, 768)
262
+ result += out
263
+ # (247, batch, 768)
264
+ return result
265
+ else:
266
+ return F.linear(x, self.linear.weight, bias=self.linear.bias)
267
+
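+ # A minimal usage sketch for MultiLoRALinear (hypothetical helper, illustration only).
+ # It assumes the LoRALayer base class initializes self.merged = False, which forward() relies on;
+ # `r` is a list with one rank per LoRA branch, as the indexing in __init__ expects.
+ def _multi_lora_linear_usage_sketch():
+     """In this version forward() sums all branches (the weighted mix is commented out above),
+     so the `weights` argument is accepted but not used."""
+     base = nn.Linear(768, 768)
+     layer = MultiLoRALinear(r=[4, 8, 8], lora_alpha=16, linear_layer=base, lora_num=3)
+     x = torch.randn(247, 16, 768)                 # (seq, batch, dim), matching the shape comments above
+     out = layer(x, weights=torch.ones(16, 3))
+     return out.shape                              # torch.Size([247, 16, 768])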
268
+ class MultiLoraConv2d(LoRALayer):
269
+ def __init__(self, r, lora_alpha, conv_layer, num_task):
270
+ """
271
+ LoRA class for nn.Conv2d class
272
+ """
273
+ super().__init__(r, lora_alpha)
274
+ self.conv = conv_layer
275
+ self.num_task = num_task
276
+
277
+ in_channels = self.conv.in_channels
278
+ out_channels = self.conv.out_channels
279
+ kernel_size = self.conv.kernel_size[0]
280
+
281
+ # lora configuration
282
+ self.lora_A_list = nn.ParameterList([nn.Parameter(self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))) for th in range(num_task)])
283
+ self.lora_B_list = nn.ParameterList([nn.Parameter(self.conv.weight.new_zeros((out_channels * kernel_size, r * kernel_size))) for th in range(num_task)])
284
+
285
+ self.scaling = self.lora_alpha / self.r
286
+ self.reset_parameters()
287
+
288
+ self.merged = False
289
+ self.label_batch = None
290
+
291
+ def reset_parameters(self):
292
+ for th in range(self.num_task):
293
+ nn.init.kaiming_uniform_(self.lora_A_list[th], a=math.sqrt(5))
294
+ nn.init.zeros_(self.lora_B_list[th])
295
+
296
+ def train(self, mode: bool = True):
297
+ self.conv.train(mode)
298
+
299
+
300
+ def eval(self):
301
+ self.conv.eval()
302
+
303
+
304
+ def forward(self, input_x, alphas=None):
305
+ if not self.merged:
306
+ conv_weight_stack = torch.cat([(self.lora_B_list[th] @ self.lora_A_list[th]).view(self.conv.weight.shape).unsqueeze(0) * self.scaling for th in range(self.num_task)], dim=0)
307
+
308
+ if isinstance(input_x, dict):
309
+ # print('input is dict')
310
+ x, alphas = input_x[0], input_x[1]
311
+
312
+ else:
313
+ x = input_x
314
+ batch_size, c = x.shape[0], x.shape[1]
315
+ # print(alphas)
316
+ if alphas is None:
317
+ print('alphas is None only inside lora_fast')
318
+ agg_weights = self.conv.weight + torch.sum(
319
+ torch.mul(conv_weight_stack.unsqueeze(0), alphas.view(batch_size, -1, 1, 1, 1, 1)), dim=1)
320
+
321
+ agg_weights = agg_weights.view(-1, *agg_weights.shape[-3:])
322
+ x_grouped = x.view(1, -1, *x.shape[-2:])
323
+
324
+ outputs = F.conv2d(x_grouped, agg_weights, self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation, groups=batch_size)
325
+ outputs = outputs.view(batch_size, -1, *outputs.shape[-2:])
326
+
327
+ return outputs
328
+ else:
329
+ return self.conv(input_x)  # merged path: the input is never unpacked above, so use input_x directly
330
+
331
+ def merged_weight(self, th): # only for test
332
+ self.conv.weight.data += (self.lora_B_list[th] @ self.lora_A_list[th]).view(self.conv.weight.shape) * self.scaling
333
+ self.merged = True
334
+
335
+
336
+
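+ # A minimal sketch of the per-sample weight mixing in MultiLoraConv2d (hypothetical helper,
+ # illustration only). Each sample gets its own aggregated kernel: the batch is folded into the
+ # channel dimension and a single grouped convolution (groups=batch_size) applies all kernels at once.
+ def _multi_lora_conv_usage_sketch():
+     base = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
+     # bias=False: the grouped trick stacks batch_size * out_channels filters, so a per-channel bias would not broadcast
+     layer = MultiLoraConv2d(r=4, lora_alpha=16, conv_layer=base, num_task=3)
+     x = torch.randn(2, 3, 32, 32)
+     alphas = torch.softmax(torch.randn(2, 3), dim=-1)  # per-sample mixing weights over the 3 task-specific LoRAs
+     y = layer(x, alphas)
+     return y.shape                                      # torch.Size([2, 8, 32, 32])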
337
+ __all__ = [
338
+ "ResNet",
339
+ "ResNet18_Weights",
340
+ "ResNet34_Weights",
341
+ "ResNet50_Weights",
342
+ "ResNet101_Weights",
343
+ "ResNet152_Weights",
344
+ "ResNeXt50_32X4D_Weights",
345
+ "ResNeXt101_32X8D_Weights",
346
+ "ResNeXt101_64X4D_Weights",
347
+ "Wide_ResNet50_2_Weights",
348
+ "Wide_ResNet101_2_Weights",
349
+ "resnet18",
350
+ "resnet34",
351
+ "resnet50",
352
+ "resnet101",
353
+ "resnet152",
354
+ "resnext50_32x4d",
355
+ "resnext101_32x8d",
356
+ "resnext101_64x4d",
357
+ "wide_resnet50_2",
358
+ "wide_resnet101_2",
359
+ ]
360
+
361
+
362
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
363
+ """3x3 convolution with padding"""
364
+ return nn.Conv2d(
365
+ in_planes,
366
+ out_planes,
367
+ kernel_size=3,
368
+ stride=stride,
369
+ padding=dilation,
370
+ groups=groups,
371
+ bias=False,
372
+ dilation=dilation,
373
+ )
374
+
375
+
376
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
377
+ """1x1 convolution"""
378
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
379
+
380
+ def conv3x3_lora(r: int, lora_alpha: float, in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
381
+ """3x3 convolution with padding"""
382
+ return LoraConv2d(
383
+ r,lora_alpha,
384
+ in_planes,
385
+ out_planes,
386
+ kernel_size=3,
387
+ stride=stride,
388
+ padding=dilation,
389
+ groups=groups,
390
+ bias=False,
391
+ dilation=dilation,
392
+ )
393
+
394
+ def conv1x1_lora(r: int, lora_alpha: float, in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
395
+ """1x1 convolution"""
396
+ return LoraConv2d(r, lora_alpha, in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
397
+
398
+
399
+
400
+ class BasicBlock_Lora(nn.Module):
401
+ expansion: int = 1
402
+
403
+ def __init__(
404
+ self,
405
+ inplanes: int,
406
+ planes: int,
407
+ r: int,
408
+ lora_alpha: float,
409
+ stride: int = 1,
410
+ downsample: Optional[nn.Module] = None,
411
+ groups: int = 1,
412
+ base_width: int = 64,
413
+ dilation: int = 1,
414
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
415
+ ) -> None:
416
+ super().__init__()
417
+ if norm_layer is None:
418
+ norm_layer = nn.BatchNorm2d
419
+ if groups != 1 or base_width != 64:
420
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
421
+ if dilation > 1:
422
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
423
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
424
+ self.conv1 = conv3x3_lora(r, lora_alpha, inplanes, planes, stride)
425
+ self.bn1 = norm_layer(planes)
426
+ self.relu = nn.ReLU(inplace=True)
427
+ self.conv2 = conv3x3_lora(r, lora_alpha, planes, planes)
428
+ self.bn2 = norm_layer(planes)
429
+ self.downsample = downsample
430
+ self.stride = stride
431
+
432
+ def forward(self, x: Tensor) -> Tensor:
433
+ identity = x
434
+
435
+ out = self.conv1(x)
436
+ out = self.bn1(out)
437
+ out = self.relu(out)
438
+
439
+ out = self.conv2(out)
440
+ out = self.bn2(out)
441
+
442
+ if self.downsample is not None:
443
+ identity = self.downsample(x)
444
+
445
+ out += identity
446
+ out = self.relu(out)
447
+
448
+ return out
449
+
450
+ class BasicBlock(nn.Module):
451
+ expansion: int = 1
452
+
453
+ def __init__(
454
+ self,
455
+ inplanes: int,
456
+ planes: int,
457
+ stride: int = 1,
458
+ downsample: Optional[nn.Module] = None,
459
+ groups: int = 1,
460
+ base_width: int = 64,
461
+ dilation: int = 1,
462
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
463
+ ) -> None:
464
+ super().__init__()
465
+ if norm_layer is None:
466
+ norm_layer = nn.BatchNorm2d
467
+ if groups != 1 or base_width != 64:
468
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
469
+ if dilation > 1:
470
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
471
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
472
+ self.conv1 = conv3x3(inplanes, planes, stride)
473
+ self.bn1 = norm_layer(planes)
474
+ self.relu = nn.ReLU(inplace=True)
475
+ self.conv2 = conv3x3(planes, planes)
476
+ self.bn2 = norm_layer(planes)
477
+ self.downsample = downsample
478
+ self.stride = stride
479
+
480
+ def forward(self, x: Tensor) -> Tensor:
481
+ identity = x
482
+
483
+ out = self.conv1(x)
484
+ out = self.bn1(out)
485
+ out = self.relu(out)
486
+
487
+ out = self.conv2(out)
488
+ out = self.bn2(out)
489
+
490
+ if self.downsample is not None:
491
+ identity = self.downsample(x)
492
+
493
+ out += identity
494
+ out = self.relu(out)
495
+
496
+ return out
497
+
498
+
499
+ class Bottleneck(nn.Module):
500
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
501
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
502
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
503
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
504
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
505
+
506
+ expansion: int = 4
507
+
508
+ def __init__(
509
+ self,
510
+ inplanes: int,
511
+ planes: int,
512
+ stride: int = 1,
513
+ downsample: Optional[nn.Module] = None,
514
+ groups: int = 1,
515
+ base_width: int = 64,
516
+ dilation: int = 1,
517
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
518
+ ) -> None:
519
+ super().__init__()
520
+ if norm_layer is None:
521
+ norm_layer = nn.BatchNorm2d
522
+ width = int(planes * (base_width / 64.0)) * groups
523
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
524
+ self.conv1 = conv1x1(inplanes, width)
525
+ self.bn1 = norm_layer(width)
526
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
527
+ self.bn2 = norm_layer(width)
528
+ self.conv3 = conv1x1(width, planes * self.expansion)
529
+ self.bn3 = norm_layer(planes * self.expansion)
530
+ self.relu = nn.ReLU(inplace=True)
531
+ self.downsample = downsample
532
+ self.stride = stride
533
+
534
+ def forward(self, x: Tensor) -> Tensor:
535
+ identity = x
536
+
537
+ out = self.conv1(x)
538
+ out = self.bn1(out)
539
+ out = self.relu(out)
540
+
541
+ out = self.conv2(out)
542
+ out = self.bn2(out)
543
+ out = self.relu(out)
544
+
545
+ out = self.conv3(out)
546
+ out = self.bn3(out)
547
+
548
+ if self.downsample is not None:
549
+ identity = self.downsample(x)
550
+
551
+ out += identity
552
+ out = self.relu(out)
553
+
554
+ return out
555
+
556
+ class Bottleneck_Lora(nn.Module):
557
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
558
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
559
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
560
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
561
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
562
+
563
+ expansion: int = 4
564
+
565
+ def __init__(
566
+ self,
567
+ inplanes: int,
568
+ planes: int,
569
+ r: int,
570
+ lora_alpha: float,
571
+ stride: int = 1,
572
+ downsample: Optional[nn.Module] = None,
573
+ groups: int = 1,
574
+ base_width: int = 64,
575
+ dilation: int = 1,
576
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
577
+ ) -> None:
578
+ super().__init__()
579
+ if norm_layer is None:
580
+ norm_layer = nn.BatchNorm2d
581
+ width = int(planes * (base_width / 64.0)) * groups
582
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
583
+ self.conv1 = conv1x1_lora(r, lora_alpha, inplanes, width)
584
+ self.bn1 = norm_layer(width)
585
+ self.conv2 = conv3x3_lora(r, lora_alpha, width, width, stride, groups, dilation)
586
+ self.bn2 = norm_layer(width)
587
+ self.conv3 = conv1x1_lora(r, lora_alpha, width, planes * self.expansion)
588
+ self.bn3 = norm_layer(planes * self.expansion)
589
+ self.relu = nn.ReLU(inplace=True)
590
+ self.downsample = downsample
591
+ self.stride = stride
592
+
593
+ def forward(self, x: Tensor) -> Tensor:
594
+ identity = x
595
+
596
+ out = self.conv1(x)
597
+ out = self.bn1(out)
598
+ out = self.relu(out)
599
+
600
+ out = self.conv2(out)
601
+ out = self.bn2(out)
602
+ out = self.relu(out)
603
+
604
+ out = self.conv3(out)
605
+ out = self.bn3(out)
606
+
607
+ if self.downsample is not None:
608
+ identity = self.downsample(x)
609
+
610
+ out += identity
611
+ out = self.relu(out)
612
+
613
+ return out
614
+
615
+
616
+ class ResNet(nn.Module):
617
+ def __init__(
618
+ self,
619
+ block: Type[Union[BasicBlock, Bottleneck]],
620
+ layers: List[int],
621
+ num_classes: int = 1000,
622
+ zero_init_residual: bool = False,
623
+ groups: int = 1,
624
+ width_per_group: int = 64,
625
+ replace_stride_with_dilation: Optional[List[bool]] = None,
626
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
627
+ ) -> None:
628
+ super().__init__()
629
+ _log_api_usage_once(self)
630
+ if norm_layer is None:
631
+ norm_layer = nn.BatchNorm2d
632
+ self._norm_layer = norm_layer
633
+
634
+ self.inplanes = 64
635
+ self.dilation = 1
636
+ if replace_stride_with_dilation is None:
637
+ # each element in the tuple indicates if we should replace
638
+ # the 2x2 stride with a dilated convolution instead
639
+ replace_stride_with_dilation = [False, False, False]
640
+ if len(replace_stride_with_dilation) != 3:
641
+ raise ValueError(
642
+ "replace_stride_with_dilation should be None "
643
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
644
+ )
645
+ self.groups = groups
646
+ self.base_width = width_per_group
647
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
648
+ self.bn1 = norm_layer(self.inplanes)
649
+ self.relu = nn.ReLU(inplace=True)
650
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
651
+ self.layer1 = self._make_layer(block, 64, layers[0])
652
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
653
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
654
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
655
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
656
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
657
+
658
+ for m in self.modules():
659
+ if isinstance(m, nn.Conv2d):
660
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
661
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
662
+ nn.init.constant_(m.weight, 1)
663
+ nn.init.constant_(m.bias, 0)
664
+
665
+ # Zero-initialize the last BN in each residual branch,
666
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
667
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
668
+ if zero_init_residual:
669
+ for m in self.modules():
670
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
671
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
672
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
673
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
674
+
675
+ def _make_layer(
676
+ self,
677
+ block: Type[Union[BasicBlock, Bottleneck]],
678
+ planes: int,
679
+ blocks: int,
680
+ stride: int = 1,
681
+ dilate: bool = False,
682
+ ) -> nn.Sequential:
683
+ norm_layer = self._norm_layer
684
+ downsample = None
685
+ previous_dilation = self.dilation
686
+ if dilate:
687
+ self.dilation *= stride
688
+ stride = 1
689
+ if stride != 1 or self.inplanes != planes * block.expansion:
690
+ downsample = nn.Sequential(
691
+ conv1x1(self.inplanes, planes * block.expansion, stride),
692
+ norm_layer(planes * block.expansion),
693
+ )
694
+
695
+ layers = []
696
+ layers.append(
697
+ block(
698
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
699
+ )
700
+ )
701
+ self.inplanes = planes * block.expansion
702
+ for _ in range(1, blocks):
703
+ layers.append(
704
+ block(
705
+ self.inplanes,
706
+ planes,
707
+ groups=self.groups,
708
+ base_width=self.base_width,
709
+ dilation=self.dilation,
710
+ norm_layer=norm_layer,
711
+ )
712
+ )
713
+
714
+ return nn.Sequential(*layers)
715
+
716
+ def _forward_impl(self, x: Tensor) -> Tensor:
717
+ # See note [TorchScript super()]
718
+ x = self.conv1(x)
719
+ x = self.bn1(x)
720
+ x = self.relu(x)
721
+ x = self.maxpool(x)
722
+
723
+ x = self.layer1(x)
724
+ x = self.layer2(x)
725
+ x = self.layer3(x)
726
+ x = self.layer4(x)
727
+
728
+ x = self.avgpool(x)
729
+ x = torch.flatten(x, 1)
730
+ x = self.fc(x)
731
+
732
+ return x
733
+
734
+ def forward(self, x: Tensor) -> Tensor:
735
+ return self._forward_impl(x)
736
+
737
+ class ResNet_Lora(nn.Module):
738
+ def __init__(
739
+ self,
740
+ block: Type[Union[BasicBlock, Bottleneck]],
741
+ layers: List[int],
742
+ r: int,
743
+ lora_alpha: float,
744
+ num_classes: int = 1000,
745
+ zero_init_residual: bool = False,
746
+ groups: int = 1,
747
+ width_per_group: int = 64,
748
+ replace_stride_with_dilation: Optional[List[bool]] = None,
749
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
750
+ ) -> None:
751
+ super().__init__()
752
+ _log_api_usage_once(self)
753
+ if norm_layer is None:
754
+ norm_layer = nn.BatchNorm2d
755
+ self._norm_layer = norm_layer
756
+
757
+ self.inplanes = 64
758
+ self.dilation = 1
759
+ if replace_stride_with_dilation is None:
760
+ # each element in the tuple indicates if we should replace
761
+ # the 2x2 stride with a dilated convolution instead
762
+ replace_stride_with_dilation = [False, False, False]
763
+ if len(replace_stride_with_dilation) != 3:
764
+ raise ValueError(
765
+ "replace_stride_with_dilation should be None "
766
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
767
+ )
768
+ self.groups = groups
769
+ self.base_width = width_per_group
770
+ self.r = r
771
+ self.lora_alpha = lora_alpha
772
+ self.conv1 = LoraConv2d(self.r, self.lora_alpha, 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
773
+ self.bn1 = norm_layer(self.inplanes)
774
+ self.relu = nn.ReLU(inplace=True)
775
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
776
+ self.layer1 = self._make_layer(block, 64, layers[0])
777
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
778
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
779
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
780
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
781
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
782
+
783
+ for m in self.modules():
784
+ if isinstance(m, nn.Conv2d):
785
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
786
+ elif isinstance(m, LoraConv2d):
787
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
788
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
789
+ nn.init.constant_(m.weight, 1)
790
+ nn.init.constant_(m.bias, 0)
791
+
792
+ # Zero-initialize the last BN in each residual branch,
793
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
794
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
795
+ if zero_init_residual:
796
+ for m in self.modules():
797
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
798
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
799
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
800
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
801
+
802
+ def _make_layer(
803
+ self,
804
+ block: Type[Union[BasicBlock, Bottleneck]],
805
+ planes: int,
806
+ blocks: int,
807
+ stride: int = 1,
808
+ dilate: bool = False,
809
+ ) -> nn.Sequential:
810
+ norm_layer = self._norm_layer
811
+ downsample = None
812
+ previous_dilation = self.dilation
813
+ if dilate:
814
+ self.dilation *= stride
815
+ stride = 1
816
+ if stride != 1 or self.inplanes != planes * block.expansion:
817
+ downsample = nn.Sequential(
818
+ conv1x1_lora(self.r, self.lora_alpha, self.inplanes, planes * block.expansion, stride),
819
+ norm_layer(planes * block.expansion),
820
+ )
821
+
822
+ layers = []
823
+ layers.append(
824
+ block(
825
+ self.inplanes, planes, self.r, self.lora_alpha, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
826
+ )
827
+ )
828
+ self.inplanes = planes * block.expansion
829
+ for _ in range(1, blocks):
830
+ layers.append(
831
+ block(
832
+ self.inplanes,
833
+ planes,
834
+ self.r,
835
+ self.lora_alpha,
836
+ groups=self.groups,
837
+ base_width=self.base_width,
838
+ dilation=self.dilation,
839
+ norm_layer=norm_layer,
840
+ )
841
+ )
842
+
843
+ return nn.Sequential(*layers)
844
+
845
+ def _forward_impl(self, x: Tensor) -> Tensor:
846
+ # See note [TorchScript super()]
847
+ x = self.conv1(x)
848
+ x = self.bn1(x)
849
+ x = self.relu(x)
850
+ x = self.maxpool(x)
851
+
852
+ x = self.layer1(x)
853
+ x = self.layer2(x)
854
+ x = self.layer3(x)
855
+ x = self.layer4(x)
856
+
857
+ x = self.avgpool(x)
858
+ x = torch.flatten(x, 1)
859
+ x = self.fc(x)
860
+
861
+ return x
862
+
863
+ def forward(self, x: Tensor) -> Tensor:
864
+ return self._forward_impl(x)
865
+
866
+
867
+
868
+ def _resnet(
869
+ block: Type[Union[BasicBlock, Bottleneck]],
870
+ layers: List[int],
871
+ weights: Optional[WeightsEnum],
872
+ progress: bool,
873
+ **kwargs: Any,
874
+ ) -> ResNet:
875
+ if weights is not None:
876
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
877
+
878
+ model = ResNet(block, layers, **kwargs)
879
+
880
+ if weights is not None:
881
+ model.load_state_dict(weights.get_state_dict(progress=progress))
882
+
883
+ return model
884
+
885
+
886
+ def _resnet_lora(
887
+ block: Type[Union[BasicBlock, Bottleneck]],
888
+ layers: List[int],
889
+ r: int,
890
+ lora_alpha: float,
891
+ weights: Optional[WeightsEnum],
892
+ progress: bool,
893
+ **kwargs: Any,
894
+ ) -> ResNet_Lora:
895
+ if weights is not None:
896
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
897
+
898
+ model = ResNet_Lora(block, layers, r, lora_alpha, **kwargs)
899
+ if weights is not None:
900
+ missing_keys, unexpected_keys = model.load_state_dict(weights.get_state_dict(progress=progress), strict=False)
901
+
902
+ for key_name in missing_keys:
903
+ if 'lora_A' in key_name or 'lora_B' in key_name:
904
+ pass
905
+ else:
906
+ raise ValueError(f'{key_name} in missing keys')
907
+
908
+ if unexpected_keys != []:
909
+ raise ValueError(f'Have unexpected keys {unexpected_keys}')
910
+
911
+ return model
912
+
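+ # A common fine-tuning pattern for the *_lora builders in this file (a sketch, not part of the
+ # original training pipeline; the helper name is made up here): freeze everything except the
+ # low-rank matrices.
+ def _freeze_all_but_lora(model: nn.Module) -> nn.Module:
+     for name, param in model.named_parameters():
+         param.requires_grad = ('lora_A' in name) or ('lora_B' in name)
+     return model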
913
+ _COMMON_META = {
914
+ "min_size": (1, 1),
915
+ "categories": _IMAGENET_CATEGORIES,
916
+ }
917
+
918
+
919
+ class ResNet18_Weights(WeightsEnum):
920
+ IMAGENET1K_V1 = Weights(
921
+ url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
922
+ transforms=partial(ImageClassification, crop_size=224),
923
+ meta={
924
+ **_COMMON_META,
925
+ "num_params": 11689512,
926
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
927
+ "_metrics": {
928
+ "ImageNet-1K": {
929
+ "acc@1": 69.758,
930
+ "acc@5": 89.078,
931
+ }
932
+ },
933
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
934
+ },
935
+ )
936
+ DEFAULT = IMAGENET1K_V1
937
+
938
+
939
+ class ResNet34_Weights(WeightsEnum):
940
+ IMAGENET1K_V1 = Weights(
941
+ url="https://download.pytorch.org/models/resnet34-b627a593.pth",
942
+ transforms=partial(ImageClassification, crop_size=224),
943
+ meta={
944
+ **_COMMON_META,
945
+ "num_params": 21797672,
946
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
947
+ "_metrics": {
948
+ "ImageNet-1K": {
949
+ "acc@1": 73.314,
950
+ "acc@5": 91.420,
951
+ }
952
+ },
953
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
954
+ },
955
+ )
956
+ DEFAULT = IMAGENET1K_V1
957
+
958
+
959
+ class ResNet50_Weights(WeightsEnum):
960
+ IMAGENET1K_V1 = Weights(
961
+ url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
962
+ transforms=partial(ImageClassification, crop_size=224),
963
+ meta={
964
+ **_COMMON_META,
965
+ "num_params": 25557032,
966
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
967
+ "_metrics": {
968
+ "ImageNet-1K": {
969
+ "acc@1": 76.130,
970
+ "acc@5": 92.862,
971
+ }
972
+ },
973
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
974
+ },
975
+ )
976
+ IMAGENET1K_V2 = Weights(
977
+ url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
978
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
979
+ meta={
980
+ **_COMMON_META,
981
+ "num_params": 25557032,
982
+ "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
983
+ "_metrics": {
984
+ "ImageNet-1K": {
985
+ "acc@1": 80.858,
986
+ "acc@5": 95.434,
987
+ }
988
+ },
989
+ "_docs": """
990
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
991
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
992
+ """,
993
+ },
994
+ )
995
+ DEFAULT = IMAGENET1K_V2
996
+
997
+
998
+ class ResNet101_Weights(WeightsEnum):
999
+ IMAGENET1K_V1 = Weights(
1000
+ url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
1001
+ transforms=partial(ImageClassification, crop_size=224),
1002
+ meta={
1003
+ **_COMMON_META,
1004
+ "num_params": 44549160,
1005
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
1006
+ "_metrics": {
1007
+ "ImageNet-1K": {
1008
+ "acc@1": 77.374,
1009
+ "acc@5": 93.546,
1010
+ }
1011
+ },
1012
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1013
+ },
1014
+ )
1015
+ IMAGENET1K_V2 = Weights(
1016
+ url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
1017
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1018
+ meta={
1019
+ **_COMMON_META,
1020
+ "num_params": 44549160,
1021
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
1022
+ "_metrics": {
1023
+ "ImageNet-1K": {
1024
+ "acc@1": 81.886,
1025
+ "acc@5": 95.780,
1026
+ }
1027
+ },
1028
+ "_docs": """
1029
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1030
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1031
+ """,
1032
+ },
1033
+ )
1034
+ DEFAULT = IMAGENET1K_V2
1035
+
1036
+
1037
+ class ResNet152_Weights(WeightsEnum):
1038
+ IMAGENET1K_V1 = Weights(
1039
+ url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
1040
+ transforms=partial(ImageClassification, crop_size=224),
1041
+ meta={
1042
+ **_COMMON_META,
1043
+ "num_params": 60192808,
1044
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
1045
+ "_metrics": {
1046
+ "ImageNet-1K": {
1047
+ "acc@1": 78.312,
1048
+ "acc@5": 94.046,
1049
+ }
1050
+ },
1051
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1052
+ },
1053
+ )
1054
+ IMAGENET1K_V2 = Weights(
1055
+ url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
1056
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1057
+ meta={
1058
+ **_COMMON_META,
1059
+ "num_params": 60192808,
1060
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
1061
+ "_metrics": {
1062
+ "ImageNet-1K": {
1063
+ "acc@1": 82.284,
1064
+ "acc@5": 96.002,
1065
+ }
1066
+ },
1067
+ "_docs": """
1068
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1069
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1070
+ """,
1071
+ },
1072
+ )
1073
+ DEFAULT = IMAGENET1K_V2
1074
+
1075
+
1076
+ class ResNeXt50_32X4D_Weights(WeightsEnum):
1077
+ IMAGENET1K_V1 = Weights(
1078
+ url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
1079
+ transforms=partial(ImageClassification, crop_size=224),
1080
+ meta={
1081
+ **_COMMON_META,
1082
+ "num_params": 25028904,
1083
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
1084
+ "_metrics": {
1085
+ "ImageNet-1K": {
1086
+ "acc@1": 77.618,
1087
+ "acc@5": 93.698,
1088
+ }
1089
+ },
1090
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1091
+ },
1092
+ )
1093
+ IMAGENET1K_V2 = Weights(
1094
+ url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
1095
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1096
+ meta={
1097
+ **_COMMON_META,
1098
+ "num_params": 25028904,
1099
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
1100
+ "_metrics": {
1101
+ "ImageNet-1K": {
1102
+ "acc@1": 81.198,
1103
+ "acc@5": 95.340,
1104
+ }
1105
+ },
1106
+ "_docs": """
1107
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1108
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1109
+ """,
1110
+ },
1111
+ )
1112
+ DEFAULT = IMAGENET1K_V2
1113
+
1114
+
1115
+ class ResNeXt101_32X8D_Weights(WeightsEnum):
1116
+ IMAGENET1K_V1 = Weights(
1117
+ url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
1118
+ transforms=partial(ImageClassification, crop_size=224),
1119
+ meta={
1120
+ **_COMMON_META,
1121
+ "num_params": 88791336,
1122
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
1123
+ "_metrics": {
1124
+ "ImageNet-1K": {
1125
+ "acc@1": 79.312,
1126
+ "acc@5": 94.526,
1127
+ }
1128
+ },
1129
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1130
+ },
1131
+ )
1132
+ IMAGENET1K_V2 = Weights(
1133
+ url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
1134
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1135
+ meta={
1136
+ **_COMMON_META,
1137
+ "num_params": 88791336,
1138
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
1139
+ "_metrics": {
1140
+ "ImageNet-1K": {
1141
+ "acc@1": 82.834,
1142
+ "acc@5": 96.228,
1143
+ }
1144
+ },
1145
+ "_docs": """
1146
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1147
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1148
+ """,
1149
+ },
1150
+ )
1151
+ DEFAULT = IMAGENET1K_V2
1152
+
1153
+
1154
+ class ResNeXt101_64X4D_Weights(WeightsEnum):
1155
+ IMAGENET1K_V1 = Weights(
1156
+ url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
1157
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1158
+ meta={
1159
+ **_COMMON_META,
1160
+ "num_params": 83455272,
1161
+ "recipe": "https://github.com/pytorch/vision/pull/5935",
1162
+ "_metrics": {
1163
+ "ImageNet-1K": {
1164
+ "acc@1": 83.246,
1165
+ "acc@5": 96.454,
1166
+ }
1167
+ },
1168
+ "_docs": """
1169
+ These weights were trained from scratch by using TorchVision's `new training recipe
1170
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1171
+ """,
1172
+ },
1173
+ )
1174
+ DEFAULT = IMAGENET1K_V1
1175
+
1176
+
1177
+ class Wide_ResNet50_2_Weights(WeightsEnum):
1178
+ IMAGENET1K_V1 = Weights(
1179
+ url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
1180
+ transforms=partial(ImageClassification, crop_size=224),
1181
+ meta={
1182
+ **_COMMON_META,
1183
+ "num_params": 68883240,
1184
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
1185
+ "_metrics": {
1186
+ "ImageNet-1K": {
1187
+ "acc@1": 78.468,
1188
+ "acc@5": 94.086,
1189
+ }
1190
+ },
1191
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1192
+ },
1193
+ )
1194
+ IMAGENET1K_V2 = Weights(
1195
+ url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
1196
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1197
+ meta={
1198
+ **_COMMON_META,
1199
+ "num_params": 68883240,
1200
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
1201
+ "_metrics": {
1202
+ "ImageNet-1K": {
1203
+ "acc@1": 81.602,
1204
+ "acc@5": 95.758,
1205
+ }
1206
+ },
1207
+ "_docs": """
1208
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1209
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1210
+ """,
1211
+ },
1212
+ )
1213
+ DEFAULT = IMAGENET1K_V2
1214
+
1215
+
1216
+ class Wide_ResNet101_2_Weights(WeightsEnum):
1217
+ IMAGENET1K_V1 = Weights(
1218
+ url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
1219
+ transforms=partial(ImageClassification, crop_size=224),
1220
+ meta={
1221
+ **_COMMON_META,
1222
+ "num_params": 126886696,
1223
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
1224
+ "_metrics": {
1225
+ "ImageNet-1K": {
1226
+ "acc@1": 78.848,
1227
+ "acc@5": 94.284,
1228
+ }
1229
+ },
1230
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
1231
+ },
1232
+ )
1233
+ IMAGENET1K_V2 = Weights(
1234
+ url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
1235
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
1236
+ meta={
1237
+ **_COMMON_META,
1238
+ "num_params": 126886696,
1239
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
1240
+ "_metrics": {
1241
+ "ImageNet-1K": {
1242
+ "acc@1": 82.510,
1243
+ "acc@5": 96.020,
1244
+ }
1245
+ },
1246
+ "_docs": """
1247
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
1248
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
1249
+ """,
1250
+ },
1251
+ )
1252
+ DEFAULT = IMAGENET1K_V2
1253
+
1254
+
1255
+ @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
1256
+ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1257
+ weights = ResNet18_Weights.verify(weights)
1258
+
1259
+ return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
1260
+
1261
+
1262
+ @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
1263
+ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1264
+ weights = ResNet34_Weights.verify(weights)
1265
+
1266
+ return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
1267
+
1268
+
1269
+ @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
1270
+ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1271
+ weights = ResNet50_Weights.verify(weights)
1272
+
1273
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
1274
+
1275
+ @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
1276
+ def resnet50_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1277
+ weights = ResNet50_Weights.verify(weights)
1278
+
1279
+ return _resnet_lora(Bottleneck_Lora, [3, 4, 6, 3], r, lora_alpha, weights, progress, **kwargs)
1280
+
1281
+ @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
1282
+ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1283
+ weights = ResNet101_Weights.verify(weights)
1284
+
1285
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
1286
+
1287
+ @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
1288
+ def resnet101_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1289
+ weights = ResNet101_Weights.verify(weights)
1290
+
1291
+ return _resnet_lora(Bottleneck_Lora, [3, 4, 23, 3], r, lora_alpha, weights, progress, **kwargs)
1292
+
1293
+
1294
+ @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
1295
+ def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1296
+ weights = ResNet152_Weights.verify(weights)
1297
+
1298
+ return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
1299
+
1300
+ @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
1301
+ def resnet152_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
1302
+ weights = ResNet152_Weights.verify(weights)
1303
+
1304
+ return _resnet_lora(Bottleneck_Lora, [3, 8, 36, 3], r, lora_alpha, weights, progress, **kwargs)
1305
+
1306
+
1307
+ @handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
1308
+ def resnext50_32x4d(
1309
+ *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
1310
+ ) -> ResNet:
1311
+ weights = ResNeXt50_32X4D_Weights.verify(weights)
1312
+
1313
+ _ovewrite_named_param(kwargs, "groups", 32)
1314
+ _ovewrite_named_param(kwargs, "width_per_group", 4)
1315
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
1316
+
1317
+
1318
+ @handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
1319
+ def resnext101_32x8d(
1320
+ *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
1321
+ ) -> ResNet:
1322
+ weights = ResNeXt101_32X8D_Weights.verify(weights)
1323
+
1324
+ _ovewrite_named_param(kwargs, "groups", 32)
1325
+ _ovewrite_named_param(kwargs, "width_per_group", 8)
1326
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
1327
+
1328
+
1329
+ def resnext101_64x4d(
1330
+ *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
1331
+ ) -> ResNet:
1332
+ weights = ResNeXt101_64X4D_Weights.verify(weights)
1333
+
1334
+ _ovewrite_named_param(kwargs, "groups", 64)
1335
+ _ovewrite_named_param(kwargs, "width_per_group", 4)
1336
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
1337
+
1338
+
1339
+ @handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
1340
+ def wide_resnet50_2(
1341
+ *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
1342
+ ) -> ResNet:
1343
+ weights = Wide_ResNet50_2_Weights.verify(weights)
1344
+
1345
+ _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
1346
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
1347
+
1348
+
1349
+ @handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
1350
+ def wide_resnet101_2(
1351
+ *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
1352
+ ) -> ResNet:
1353
+ weights = Wide_ResNet101_2_Weights.verify(weights)
1354
+
1355
+ _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
1356
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
1357
+
1358
+
1359
+ # The dictionary below is internal implementation detail and will be removed in v0.15
1360
+ from torchvision.models._utils import _ModelURLs
1361
+
1362
+
1363
+ model_urls = _ModelURLs(
1364
+ {
1365
+ "resnet18": ResNet18_Weights.IMAGENET1K_V1.url,
1366
+ "resnet34": ResNet34_Weights.IMAGENET1K_V1.url,
1367
+ "resnet50": ResNet50_Weights.IMAGENET1K_V1.url,
1368
+ "resnet101": ResNet101_Weights.IMAGENET1K_V1.url,
1369
+ "resnet152": ResNet152_Weights.IMAGENET1K_V1.url,
1370
+ "resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url,
1371
+ "resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url,
1372
+ "wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url,
1373
+ "wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url,
1374
+ }
1375
+ )
1376
+
1377
+
1378
+ if __name__ == '__main__':
1379
+ model = resnet50_lora(r=16, lora_alpha=16, weights='ResNet50_Weights.IMAGENET1K_V2')
1380
+
1381
+
1382
+
models/transformer_decoder.py ADDED
@@ -0,0 +1,320 @@
1
+ import copy
2
+ from typing import Optional, List
3
+ import pickle as cp
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor
8
+ # from visualizer import get_local
9
+
10
+ class TransformerDecoder(nn.Module):
11
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
12
+ super().__init__()
13
+ self.layers = _get_clones(decoder_layer, num_layers)
14
+ self.num_layers = num_layers
15
+ self.norm = norm
16
+ self.return_intermediate = return_intermediate
17
+
18
+ def forward(self,tgt, memory,
19
+ tgt_mask: Optional[Tensor] = None,
20
+ memory_mask: Optional[Tensor] = None,
21
+ tgt_key_padding_mask: Optional[Tensor] = None,
22
+ memory_key_padding_mask: Optional[Tensor] = None,
23
+ pos: Optional[Tensor] = None,
24
+ query_pos: Optional[Tensor] = None):
25
+ output = tgt
26
+ T,B,C = memory.shape
27
+ intermediate = []
28
+ for n,layer in enumerate(self.layers):
29
+ residual=True
30
+ output,ws = layer(output, memory, tgt_mask=tgt_mask,
31
+ memory_mask=memory_mask,
32
+ tgt_key_padding_mask=tgt_key_padding_mask,
33
+ memory_key_padding_mask=memory_key_padding_mask,
34
+ pos=pos, query_pos=query_pos,residual=residual)
35
+ if self.return_intermediate:
36
+ intermediate.append(self.norm(output))
37
+ if self.norm is not None:
38
+ output = self.norm(output)
39
+ if self.return_intermediate:
40
+ intermediate.pop()
41
+ intermediate.append(output)
42
+
43
+ if self.return_intermediate:
44
+ return torch.stack(intermediate)
45
+ return output
46
+
47
+
48
+ class TransformerDecoderLayer(nn.Module):
49
+ def __init__(self, d_model, nhead, dim_feedforward=1024, dropout=0.1,
50
+ activation="relu", normalize_before=False):
51
+ super().__init__()
52
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
53
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
54
+ # Implementation of Feedforward model
55
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
56
+ self.dropout = nn.Dropout(dropout)
57
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
58
+
59
+ self.norm1 = nn.LayerNorm(d_model)
60
+ self.norm2 = nn.LayerNorm(d_model)
61
+ self.norm3 = nn.LayerNorm(d_model)
62
+ self.dropout1 = nn.Dropout(dropout)
63
+ self.dropout2 = nn.Dropout(dropout)
64
+ self.dropout3 = nn.Dropout(dropout)
65
+
66
+ self.activation = _get_activation_fn(activation)
67
+ self.normalize_before = normalize_before
68
+
69
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
70
+ return tensor if pos is None else tensor + pos
71
+
72
+ def forward_post(self, tgt, memory,
73
+ tgt_mask: Optional[Tensor] = None,
74
+ memory_mask: Optional[Tensor] = None,
75
+ tgt_key_padding_mask: Optional[Tensor] = None,
76
+ memory_key_padding_mask: Optional[Tensor] = None,
77
+ pos: Optional[Tensor] = None,
78
+ query_pos: Optional[Tensor] = None,
79
+ residual=True):
80
+ #tgt: text_features, torch.Size([14, 1, 768])
81
+ #memory: image_features, torch.Size([49, 1, 768])
82
+ q = k = self.with_pos_embed(tgt, query_pos)
83
+ tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
84
+ key_padding_mask=tgt_key_padding_mask)
85
+ tgt = self.norm1(tgt)
86
+ tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
87
+ key=self.with_pos_embed(memory, pos),
88
+ need_weights = True,
89
+ value=memory, attn_mask=memory_mask,
90
+ key_padding_mask=memory_key_padding_mask)
91
+
92
+
93
+ # attn_weights [B,NUM_Q,T]
94
+ tgt = tgt + self.dropout2(tgt2)
95
+ tgt = self.norm2(tgt)
96
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
97
+ tgt = tgt + self.dropout3(tgt2)
98
+ tgt = self.norm3(tgt)
99
+ return tgt,ws
100
+
101
+ # @get_local('attn_weights')
102
+ def forward_pre(self, tgt, memory,
103
+ tgt_mask: Optional[Tensor] = None,
104
+ memory_mask: Optional[Tensor] = None,
105
+ tgt_key_padding_mask: Optional[Tensor] = None,
106
+ memory_key_padding_mask: Optional[Tensor] = None,
107
+ pos: Optional[Tensor] = None,
108
+ query_pos: Optional[Tensor] = None):
109
+ tgt2 = self.norm1(tgt)
110
+ q = k = self.with_pos_embed(tgt2, query_pos)
111
+ tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
112
+ key_padding_mask=tgt_key_padding_mask)
113
+ # print('self atten',ws.shape)
114
+ tgt = tgt + self.dropout1(tgt2)
115
+ tgt2 = self.norm2(tgt)
116
+ tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
117
+ key=self.with_pos_embed(memory, pos),
118
+ value=memory, attn_mask=memory_mask,
119
+ key_padding_mask=memory_key_padding_mask)# (N,num_heads,L,S)
120
+ # print('self attn_weights',attn_weights.shape)
121
+ tgt = tgt + self.dropout2(tgt2)
122
+ tgt2 = self.norm3(tgt)
123
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
124
+ tgt = tgt + self.dropout3(tgt2)
125
+ return tgt,attn_weights
126
+
127
+ def forward(self, tgt, memory,
128
+ tgt_mask: Optional[Tensor] = None,
129
+ memory_mask: Optional[Tensor] = None,
130
+ tgt_key_padding_mask: Optional[Tensor] = None,
131
+ memory_key_padding_mask: Optional[Tensor] = None,
132
+ pos: Optional[Tensor] = None,
133
+ query_pos: Optional[Tensor] = None,
134
+ residual=True):
135
+ if self.normalize_before:
136
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
137
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
138
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
139
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
140
+
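+ # A minimal sketch of wiring the decoder above (hypothetical helper, illustration only):
+ # query/text features attend over image patch features, mirroring the shape comments in forward_post.
+ def _decoder_usage_sketch():
+     layer = TransformerDecoderLayer(d_model=768, nhead=8, normalize_before=True)
+     decoder = TransformerDecoder(layer, num_layers=2, norm=nn.LayerNorm(768))
+     tgt = torch.zeros(14, 1, 768)       # queries, e.g. text/label embeddings
+     memory = torch.randn(49, 1, 768)    # image patch features
+     out = decoder(tgt, memory)
+     return out.shape                    # torch.Size([14, 1, 768])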
141
+ class TransformerDecoderV1(nn.Module):
142
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
143
+ super().__init__()
144
+ self.layers = _get_clones(decoder_layer, num_layers)
145
+ self.num_layers = num_layers
146
+ self.norm = norm
147
+ self.return_intermediate = return_intermediate
148
+
149
+ def forward(self,tgt, memory,
150
+ memory_global, #
151
+ tgt_mask: Optional[Tensor] = None,
152
+ memory_mask: Optional[Tensor] = None,
153
+ tgt_key_padding_mask: Optional[Tensor] = None,
154
+ memory_key_padding_mask: Optional[Tensor] = None,
155
+ pos: Optional[Tensor] = None,
156
+ query_pos: Optional[Tensor] = None):
157
+ output = tgt
158
+ T,B,C = memory.shape
159
+ intermediate = []
160
+ for n,layer in enumerate(self.layers):
161
+ residual=True
162
+ output,ws = layer(output, memory,
163
+ memory_global, #
164
+
165
+ tgt_mask=tgt_mask,
166
+ memory_mask=memory_mask,
167
+ tgt_key_padding_mask=tgt_key_padding_mask,
168
+ memory_key_padding_mask=memory_key_padding_mask,
169
+ pos=pos, query_pos=query_pos,residual=residual)
170
+
171
+ if self.return_intermediate:
172
+ intermediate.append(self.norm(output))
173
+ if self.norm is not None:
174
+ output = self.norm(output)
175
+ #
176
+
177
+ if self.return_intermediate:
178
+ intermediate.pop()
179
+ intermediate.append(output)
180
+
181
+ if self.return_intermediate:
182
+ return torch.stack(intermediate)
183
+
184
+ return output,ws
185
+
186
+
187
+ class TransformerDecoderLayerV1(nn.Module):
188
+ def __init__(self, d_model, nhead, dim_feedforward=1024, dropout=0.1,
189
+ activation="relu", normalize_before=False, lam = [1,0]):
190
+ super().__init__()
191
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
192
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
193
+ # Implementation of Feedforward model
194
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
195
+ self.dropout = nn.Dropout(dropout)
196
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
197
+
198
+ self.norm1 = nn.LayerNorm(d_model)
199
+ self.norm2 = nn.LayerNorm(d_model)
200
+ self.norm3 = nn.LayerNorm(d_model)
201
+ self.dropout1 = nn.Dropout(dropout)
202
+ self.dropout2 = nn.Dropout(dropout)
203
+ self.dropout3 = nn.Dropout(dropout)
204
+
205
+ self.activation = _get_activation_fn(activation)
206
+ self.normalize_before = normalize_before
207
+
208
+ # self.lam = lam
209
+ self.lam_l = lam[0]
210
+ self.lam_g = lam[1]
211
+
212
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
213
+ return tensor if pos is None else tensor + pos
214
+
215
+ def forward_post(self, tgt, memory,
216
+ memory_global,
217
+ tgt_mask: Optional[Tensor] = None,
218
+ memory_mask: Optional[Tensor] = None,
219
+ tgt_key_padding_mask: Optional[Tensor] = None,
220
+ memory_key_padding_mask: Optional[Tensor] = None,
221
+ pos: Optional[Tensor] = None,
222
+ query_pos: Optional[Tensor] = None,
223
+ residual=True):
224
+ #tgt: text_features, torch.Size([14, 1, 768])
225
+ #memory: image_features, torch.Size([49, 1, 768])
226
+ q = k = self.with_pos_embed(tgt, query_pos)
227
+ tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
228
+ key_padding_mask=tgt_key_padding_mask)
229
+ tgt = self.norm1(tgt)
230
+ tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
231
+ key=self.with_pos_embed(memory, pos),
232
+ need_weights = True,
233
+ value=memory, attn_mask=memory_mask,
234
+ key_padding_mask=memory_key_padding_mask)
235
+
236
+
237
+ # attn_weights [B,NUM_Q,T]
238
+ tgt = tgt + self.dropout2(tgt2)
239
+ tgt = self.norm2(tgt)
240
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
241
+ tgt = tgt + self.dropout3(tgt2)
242
+ tgt = self.norm3(tgt)
243
+ return tgt,ws
244
+
245
+ # @get_local('attn_weights')
246
+ def forward_pre(self, tgt, memory,
247
+ memory_global, #
248
+ tgt_mask: Optional[Tensor] = None,
249
+ memory_mask: Optional[Tensor] = None,
250
+ tgt_key_padding_mask: Optional[Tensor] = None,
251
+ memory_key_padding_mask: Optional[Tensor] = None,
252
+ pos: Optional[Tensor] = None,
253
+ query_pos: Optional[Tensor] = None):
254
+ tgt2 = self.norm1(tgt)
255
+ q = k = self.with_pos_embed(tgt2, query_pos)
256
+ tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
257
+ key_padding_mask=tgt_key_padding_mask)
258
+ # print('self atten',ws.shape)
259
+ tgt = tgt + self.dropout1(tgt2)
260
+ tgt2 = self.norm2(tgt)
261
+ if memory.shape[0] == 1:
262
+ tgt2_fine,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
263
+ key=self.with_pos_embed(memory, pos),
264
+ value=memory, attn_mask=memory_mask,
265
+ key_padding_mask=memory_key_padding_mask)# (N,num_heads,L,S)
266
+ tgt2 = tgt2_fine
267
+ attn_weights = attn_weights
268
+
269
+ else:
270
+ tgt2_fine,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
271
+ key=self.with_pos_embed(memory, pos),
272
+ value=memory, attn_mask=memory_mask,
273
+ key_padding_mask=memory_key_padding_mask)# (N,num_heads,L,S)
274
+ tgt2_global,attn_weights_global = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
275
+ key=self.with_pos_embed(memory_global, pos),
276
+ value=memory_global, attn_mask=memory_mask,
277
+ key_padding_mask=memory_key_padding_mask)# (N,num_heads,L,S)
278
+ # tgt2 = tgt2_fine + self.lam * tgt2_global
279
+ tgt2 = tgt2_fine*self.lam_l + tgt2_global*self.lam_g
280
+ # attn_weights = attn_weights + self.lam * attn_weights_global
281
+ attn_weights = attn_weights*self.lam_l + attn_weights_global*self.lam_g
282
+
283
+ tgt = tgt + self.dropout2(tgt2)
284
+ tgt2 = self.norm3(tgt)
285
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
286
+ tgt = tgt + self.dropout3(tgt2)
287
+
288
+ return tgt, attn_weights
289
+
290
+ def forward(self, tgt, memory,
291
+ memory_global,
292
+ tgt_mask: Optional[Tensor] = None,
293
+ memory_mask: Optional[Tensor] = None,
294
+ tgt_key_padding_mask: Optional[Tensor] = None,
295
+ memory_key_padding_mask: Optional[Tensor] = None,
296
+ pos: Optional[Tensor] = None,
297
+ query_pos: Optional[Tensor] = None,
298
+ residual=True):
299
+ if self.normalize_before:
300
+ return self.forward_pre(tgt, memory, memory_global, tgt_mask, memory_mask,
301
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
302
+ return self.forward_post(tgt, memory, memory_global, tgt_mask, memory_mask,
303
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
304
+
305
+
306
+
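+ # A minimal sketch for the V1 variant (hypothetical helper, illustration only): queries attend over
+ # fine-grained patch features and a global feature, mixed with the lam = [local, global] weights.
+ def _decoder_v1_usage_sketch():
+     layer = TransformerDecoderLayerV1(d_model=768, nhead=8, normalize_before=True, lam=[0.7, 0.3])
+     decoder = TransformerDecoderV1(layer, num_layers=2, norm=nn.LayerNorm(768))
+     tgt = torch.zeros(14, 1, 768)            # queries
+     memory = torch.randn(49, 1, 768)         # fine-grained image patch features
+     memory_global = torch.randn(1, 1, 768)   # single pooled/global image token (broadcasts in the attn-weight mix)
+     out, attn = decoder(tgt, memory, memory_global)
+     return out.shape, attn.shape             # torch.Size([14, 1, 768]), torch.Size([1, 14, 49])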
307
+ def _get_clones(module, N):
308
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
309
+
310
+
311
+
312
+ def _get_activation_fn(activation):
313
+ """Return an activation function given a string"""
314
+ if activation == "relu":
315
+ return F.relu
316
+ if activation == "gelu":
317
+ return F.gelu
318
+ if activation == "glu":
319
+ return F.glu
320
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")