Upload 8 files

- factory/.DS_Store +0 -0
- factory/loss.py +269 -0
- factory/metric.py +23 -0
- factory/utils.py +403 -0
- models/.DS_Store +0 -0
- models/clip_tqn.py +546 -0
- models/resnet.py +1382 -0
- models/transformer_decoder.py +320 -0
factory/.DS_Store
ADDED
Binary file (6.15 kB)
factory/loss.py
ADDED
@@ -0,0 +1,269 @@

import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):
    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features):
        logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        device = image_features.device
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        # calculated ground-truth and cache if enabled
        num_logits = logits_per_image.shape[0]
        labels = torch.eye(num_logits, device=device, dtype=torch.float)
        pred_1 = F.log_softmax(logits_per_image, dim=-1)
        pred_2 = F.log_softmax(logits_per_text, dim=-1)
        loss_a = F.kl_div(pred_1, labels, reduction='sum') / num_logits
        loss_b = F.kl_div(pred_2, labels, reduction='sum') / num_logits
        total_loss = (loss_a + loss_b) / 2
        return total_loss


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y, use_weight=False):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        if use_weight:
            return loss
        return -loss.sum()

class RalSingleLoss(nn.Module):
    '''
    This loss is intended for single-label classification problems
    '''
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, epsilon_pos_pow=-2.5, reduction='mean'):
        super(RalSingleLoss, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction
        self.epsilon_pos = 1.0
        self.epsilon_neg = 0.0
        self.epsilon_pos_pow = epsilon_pos_pow
        self.lamb = 1.5

    def forward(self, inputs, target):
        '''
        "input" dimensions: - (batch_size, number_classes)
        "target" dimensions: - (batch_size)
        '''
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = torch.exp(log_preds) * (torch.log(xs_pos.clamp(min=self.eps)) + self.epsilon_pos * (1 - xs_pos.clamp(min=self.eps)) + self.epsilon_pos_pow * 0.5 * torch.pow(1 - xs_pos.clamp(min=self.eps), 2)) * torch.log(xs_pos)
        xs_neg = (1 - xs_pos) * (torch.log(xs_neg.clamp(min=self.eps)) + self.epsilon_neg * (xs_neg.clamp(min=self.eps))) * -(self.lamb - xs_neg) * xs_neg ** 2
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss

class Ralloss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8, lamb=1.5, epsilon_neg=0.0, epsilon_pos=1.0, epsilon_pos_pow=-2.5, disable_torch_grad_focal_loss=False):
        super(Ralloss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # parameters of Taylor expansion polynomials
        self.epsilon_pos = epsilon_pos
        self.epsilon_neg = epsilon_neg
        self.epsilon_pos_pow = epsilon_pos_pow
        self.margin = 1.0
        self.lamb = lamb

    def forward(self, x, y, use_weight=False):
        """
        x: input logits with size (batch_size, number of labels).
        y: binarized multi-label targets with size (batch_size, number of labels).
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic Taylor expansion polynomials
        los_pos = y * (torch.log(xs_pos.clamp(min=self.eps)) + self.epsilon_pos * (1 - xs_pos.clamp(min=self.eps)) + self.epsilon_pos_pow * 0.5 * torch.pow(1 - xs_pos.clamp(min=self.eps), 2))
        los_neg = (1 - y) * (torch.log(xs_neg.clamp(min=self.eps)) + self.epsilon_neg * (xs_neg.clamp(min=self.eps))) * (self.lamb - x_sigmoid) * x_sigmoid ** 2 * (self.lamb - xs_neg)
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        if use_weight:
            return loss
        return -loss.sum()
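The losses above are plain nn.Modules. As a quick single-process sanity check, here is a minimal sketch of ClipLoss on random embeddings; the import path assumes the repository root is on sys.path, and the batch/dim shapes are illustrative rather than taken from the training code. Note that, as written, ClipLoss recreates logit_scale inside forward, so the scale is a fresh constant each call rather than a learned parameter.

import torch
import torch.nn.functional as F
from factory.loss import ClipLoss

# Random, L2-normalized CLIP-style embeddings (batch, dim); values are illustrative.
batch, dim = 8, 768
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)

loss_fn = ClipLoss()  # single-process defaults: world_size=1, rank=0
loss = loss_fn(image_features, text_features)  # symmetric KL against the identity matching
print(loss.item())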
factory/metric.py
ADDED
@@ -0,0 +1,23 @@

import sklearn
from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score, matthews_corrcoef

# gt_np = [0,1,1,0]
# pred_np = [0.2,0.7,0.8,0.3]
# fps,tps,thresholds = sklearn.metrics._ranking._binary_clf_curve(gt_np, pred_np)
# precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
# print(fps,tps,thresholds)  # tps: an increasing count of true positives
# print(precision[:-1], recall[:-1], thresholds)
# tns = tps/recall[:-1] - tps

# print(tns + fps + tps)
# fns =
# [0. 1. 2.] [2. 2. 2.] [0.8 0.3 0.2]
# precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
# numerator = 2 * recall * precision
# denom = recall + precision
# f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0))
# max_f1 = np.max(f1_scores)
# max_f1_thresh = thresholds[np.argmax(f1_scores)]
# print('The max_f1_thresh is', max_f1_thresh)
# print('The average f1_score is', max_f1)
# print('The average acc_score is', accuracy_score(gt_np, pred_np>max_f1_thresh))
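The commented scratch above sketches how to pick a decision threshold from the precision-recall curve. For reference, a runnable version of that recipe, using the toy labels from the comments (illustrative only, not part of the committed file):

import numpy as np
from sklearn.metrics import precision_recall_curve, accuracy_score

gt_np = np.array([0, 1, 1, 0])
pred_np = np.array([0.2, 0.7, 0.8, 0.3])

# F1 at every point on the precision-recall curve, guarding against 0/0.
precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
numerator = 2 * recall * precision
denom = recall + precision
f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))

# Threshold that maximizes F1, then accuracy when binarizing at that threshold.
max_f1 = np.max(f1_scores)
max_f1_thresh = thresholds[np.argmax(f1_scores)]
print('The max_f1_thresh is', max_f1_thresh)
print('The max f1_score is', max_f1)
print('The acc_score at that threshold is', accuracy_score(gt_np, pred_np > max_f1_thresh))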
factory/utils.py
ADDED
@@ -0,0 +1,403 @@

import numpy as np
import io
import os
import time
import random
from collections import defaultdict, deque
import datetime
import subprocess
import torch
import torch.distributed as dist

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        if self.count == 0:
            return self.total
        else:
            return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f}".format(name, meter.global_avg)
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def compute_acc(logits, label, reduction='mean'):
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    elif reduction == 'mean':
        return ret.mean().item()

def compute_n_params(model, return_str=True):
    tot = 0
    for p in model.parameters():
        w = 1
        for x in p.shape:
            w *= x
        tot += w
    if return_str:
        if tot >= 1e6:
            return '{:.1f}M'.format(tot / 1e6)
        else:
            return '{:.1f}K'.format(tot / 1e3)
    else:
        return tot

def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        print('on itp')
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        print('rank')
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    # args.distributed = False

    # torch.cuda.set_device(args.gpu)
    # args.dist_backend = 'gloo'
    # print('| distributed init (rank {}): {}, gpu {}'.format(
    #     args.rank, args.dist_url, args.gpu), flush=True)
    # print("flag1")
    # print(args)
    # torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
    #                                      world_size=args.world_size, rank=args.rank)

    # print("flag2")
    # torch.distributed.barrier()
    # setup_for_distributed(args.rank == 0)

    args.distributed = False
    args.dist_url = 'tcp://localhost:12345'
    args.world_size = 1
    args.rank = 0


# def init_distributed_mode(args, port='29511'):
#     num_gpus = torch.cuda.device_count()
#     if args.dist_on_itp:
#         args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
#         args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
#         args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
#         args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
#         os.environ['LOCAL_RANK'] = str(args.gpu)
#         os.environ['RANK'] = str(args.rank)
#         os.environ['WORLD_SIZE'] = str(args.world_size)
#         # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
#     elif "SLURM_JOB_ID" in os.environ:
#         print('SLURM_JOB_ID')
#         args.rank = int(os.environ["SLURM_PROCID"])
#         args.world_size = int(os.environ["SLURM_NTASKS"])
#         node_list = os.environ["SLURM_NODELIST"]
#         addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
#         # specify master port
#         if port is not None:
#             os.environ["MASTER_PORT"] = str(port)
#         elif "MASTER_PORT" not in os.environ:
#             os.environ["MASTER_PORT"] = "29400"
#         if "MASTER_ADDR" not in os.environ:
#             os.environ["MASTER_ADDR"] = addr
#         os.environ["WORLD_SIZE"] = str(args.world_size)
#         os.environ["LOCAL_RANK"] = str(args.rank % num_gpus)
#         os.environ["RANK"] = os.environ["WORLD_SIZE"]
#         args.gpu = args.rank % torch.cuda.device_count()
#     elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
#         print('RANK')
#         args.rank = int(os.environ["RANK"])
#         args.world_size = int(os.environ['WORLD_SIZE'])
#         args.gpu = int(os.environ['LOCAL_RANK'])
#     else:
#         print('Not using distributed mode')
#         args.distributed = False
#         return
#
#     args.distributed = True
#
#     torch.cuda.set_device(args.gpu)
#     args.dist_backend = 'nccl'
#
#     print('| distributed init (rank {}): {}, gpu {}'.format(
#         args.rank, args.dist_url, args.gpu), flush=True)
#     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
#                                          world_size=args.world_size, rank=args.rank)
#     print('Init_process_group')
#     torch.distributed.barrier()
#     print('distributed.barrier')
#     setup_for_distributed(args.rank == 0)
#     print('Finished distributed')


# def init_distributed_mode(args):
#     # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
#     # args.local_rank = os.environ['LOCAL_RANK']
#     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
#         args.rank = int(os.environ["RANK"])
#         args.world_size = int(os.environ['WORLD_SIZE'])
#         args.local_rank = int(os.environ['LOCAL_RANK'])
#     elif 'SLURM_PROCID' in os.environ:
#         args.rank = int(os.environ['SLURM_PROCID'])
#         args.local_rank = args.rank % torch.cuda.device_count()
#     else:
#         print('Not using distributed mode')
#         args.distributed = False
#         return
#
#     args.distributed = True
#
#     torch.cuda.set_device(args.local_rank)
#     args.dist_backend = 'nccl'
#     print('| distributed init (rank {}): {}'.format(
#         args.rank, args.dist_url), flush=True)
#     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
#                                          world_size=args.world_size, rank=args.rank)
#     torch.distributed.barrier()
#     setup_for_distributed(args.rank == 0)

# def init_distributed_mode(args):
#     # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
#     #     args.rank = int(os.environ["RANK"])
#     #     args.world_size = int(os.environ['WORLD_SIZE'])
#     #     args.gpu = int(os.environ['LOCAL_RANK'])
#     # elif 'SLURM_PROCID' in os.environ:
#     #     args.rank = int(os.environ['SLURM_PROCID'])
#     #     args.gpu = args.rank % torch.cuda.device_count()
#     # else:
#     #     print('Not using distributed mode')
#     #     args.distributed = False
#     #     return
#     # rank = int(os.environ['RANK'])  # system env process ranks
#     # print(torch.distributed.get_world_size())
#
#     args.distributed = True
#     # torch.cuda.set_device(args.gpu)
#     num_gpus = torch.cuda.device_count()  # Returns the number of GPUs available
#     torch.cuda.set_device(args.rank % num_gpus)
#     # args.gpu = args.rank % torch.cuda.device_count()
#
#     args.dist_backend = 'nccl'
#     print('| distributed init (rank {}): {}'.format(
#         args.rank, args.dist_url), flush=True)
#     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
#                                          world_size=args.world_size, rank=args.rank)
#     torch.distributed.barrier()
#     print('using distributed mode', args.rank, args.dist_url)
#     setup_for_distributed(args.rank == 0)

# # export MASTER_ADDR=localhost
# export MASTER_PORT=5678
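SmoothedValue and MetricLogger are meant to be used together: named meters accumulate scalar stats while log_every wraps the data loader and prints timing, ETA, and meter values. A minimal sketch (the meter names and the stand-in loop are illustrative assumptions, not taken from this commit's training script):

import torch
from factory.utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
# A window of 1 makes the 'lr' meter report the latest value verbatim.
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

batches = [torch.randn(4, 3) for _ in range(10)]  # stand-in for a DataLoader
for batch in logger.log_every(batches, print_freq=5, header='Train:'):
    loss = batch.pow(2).mean()  # placeholder computation
    logger.update(loss=loss.item(), lr=1e-4)

print('Averaged stats:', logger.global_avg())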
models/.DS_Store
ADDED
Binary file (6.15 kB)
models/clip_tqn.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
# 加入父文件夹路径到sys.path中
|
| 4 |
+
sys.path.append(sys.path[0].replace('models', ''))
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import json
|
| 10 |
+
import pathlib
|
| 11 |
+
import numpy as np
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Tuple, Union, Callable, Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torchvision.models as models
|
| 24 |
+
from torch.utils.checkpoint import checkpoint
|
| 25 |
+
|
| 26 |
+
from transformers import AutoModel,BertConfig,AutoTokenizer
|
| 27 |
+
# from pytorch_pretrained_vit import ViT
|
| 28 |
+
|
| 29 |
+
# from visualizer import get_local
|
| 30 |
+
from models.transformer_decoder import *
|
| 31 |
+
|
| 32 |
+
# from io import BytesIO
|
| 33 |
+
# from petrel_client.client import Client
|
| 34 |
+
|
| 35 |
+
# conf_path = '~/petreloss.conf'
|
| 36 |
+
# client = Client(conf_path)
|
| 37 |
+
from torch.autograd import Function
|
| 38 |
+
import timm
|
| 39 |
+
|
| 40 |
+
class ReverseLayerF(Function):
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def forward(ctx, x, alpha):
|
| 44 |
+
ctx.alpha = alpha
|
| 45 |
+
return x.view_as(x)
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def backward(ctx, grad_output):
|
| 49 |
+
return grad_output.neg() * ctx.alpha, None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DomainClassifier(nn.Module):
|
| 53 |
+
'''一个单层分类器 带梯度反转层'''
|
| 54 |
+
def __init__(self, domain_nums=4, feature_dims=768):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.domain_nums = domain_nums
|
| 57 |
+
self.feature_dims = feature_dims
|
| 58 |
+
self.fc = nn.Linear(feature_dims, domain_nums)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
reverse_x = ReverseLayerF.apply(x, 1.0)
|
| 62 |
+
return self.fc(reverse_x)
|
| 63 |
+
|
| 64 |
+
class CLP_clinical(nn.Module):
|
| 65 |
+
def __init__(self,
|
| 66 |
+
bert_model_name: str,
|
| 67 |
+
embed_dim: int = 768,
|
| 68 |
+
freeze_layers:Union[Tuple[int, int], int] = None):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.bert_model = self._get_bert_basemodel(bert_model_name=bert_model_name, freeze_layers=freeze_layers)
|
| 71 |
+
self.mlp_embed = nn.Sequential(
|
| 72 |
+
nn.Linear(embed_dim, embed_dim),
|
| 73 |
+
nn.GELU(),
|
| 74 |
+
nn.Linear(embed_dim, embed_dim)
|
| 75 |
+
)
|
| 76 |
+
self.embed_dim = embed_dim
|
| 77 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 78 |
+
self.init_parameters()
|
| 79 |
+
|
| 80 |
+
def init_parameters(self):
|
| 81 |
+
nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
|
| 82 |
+
for m in self.mlp_embed:
|
| 83 |
+
if isinstance(m, nn.Linear):
|
| 84 |
+
nn.init.normal_(m.weight, std=self.embed_dim ** -0.5)
|
| 85 |
+
|
| 86 |
+
def _get_bert_basemodel(self, bert_model_name, freeze_layers=None):#12
|
| 87 |
+
try:
|
| 88 |
+
print(bert_model_name)
|
| 89 |
+
config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)#bert-base-uncased
|
| 90 |
+
model = AutoModel.from_pretrained(bert_model_name, config=config)#, return_dict=True)
|
| 91 |
+
print("Text feature extractor:", bert_model_name)
|
| 92 |
+
print("bert encoder layers:",len(model.encoder.layer))
|
| 93 |
+
except:
|
| 94 |
+
raise ("Invalid model name. Check the config file and pass a BERT model from transformers lybrary")
|
| 95 |
+
|
| 96 |
+
if freeze_layers is not None:
|
| 97 |
+
for layer_idx in freeze_layers:
|
| 98 |
+
for param in list(model.encoder.layer[layer_idx].parameters()):
|
| 99 |
+
param.requires_grad = False
|
| 100 |
+
return model
|
| 101 |
+
|
| 102 |
+
def encode_text(self, text):
|
| 103 |
+
#input batch_size,token, return batch_size,dim
|
| 104 |
+
output = self.bert_model(input_ids = text['input_ids'],attention_mask = text['attention_mask'] )
|
| 105 |
+
last_hidden_state, pooler_output, hidden_states = output[0],output[1],output[2]
|
| 106 |
+
encode_out = self.mlp_embed(pooler_output)
|
| 107 |
+
# encode_out = pooler_output
|
| 108 |
+
return encode_out
|
| 109 |
+
|
| 110 |
+
def forward(self, text):
|
| 111 |
+
#input batch_size,token, return batch_size,dim
|
| 112 |
+
output = self.bert_model(input_ids = text['input_ids'],attention_mask = text['attention_mask'] )
|
| 113 |
+
last_hidden_state, pooler_output, hidden_states = output[0],output[1],output[2]
|
| 114 |
+
encode_out = self.mlp_embed(pooler_output)
|
| 115 |
+
# encode_out = pooler_output
|
| 116 |
+
return encode_out
|
| 117 |
+
|
| 118 |
+
# def forward(self,text1,text2):
|
| 119 |
+
# text1_features = self.encode_text(text1)
|
| 120 |
+
# text2_features = self.encode_text(text2)
|
| 121 |
+
# text1_features = F.normalize(text1_features, dim=-1)
|
| 122 |
+
# text2_features = F.normalize(text2_features, dim=-1)
|
| 123 |
+
# return text1_features, text2_features, self.logit_scale.exp()
|
| 124 |
+
|
| 125 |
+
class ModelRes(nn.Module):
|
| 126 |
+
def __init__(self, res_base_model):
|
| 127 |
+
super(ModelRes, self).__init__()
|
| 128 |
+
self.resnet_dict = {
|
| 129 |
+
"resnet50": models.resnet50(pretrained=True),
|
| 130 |
+
"resnet101": models.resnet101(pretrained=True),
|
| 131 |
+
"resnet152": models.resnet152(pretrained=True),
|
| 132 |
+
"resnet50_openai": None,
|
| 133 |
+
'resnet101_openai': None,
|
| 134 |
+
'resnet50x4_openai': None,
|
| 135 |
+
}
|
| 136 |
+
# "resnet50": models.resnet50(pretrained=True)}
|
| 137 |
+
self.resnet = self._get_res_basemodel(res_base_model)
|
| 138 |
+
# num_ftrs = int(self.resnet.fc.in_features/2)
|
| 139 |
+
# self.res_features = nn.Sequential(*list(self.resnet.children())[:-3]) 224
|
| 140 |
+
if 'openai' in res_base_model:
|
| 141 |
+
# 重新定义res_features
|
| 142 |
+
num_ftrs = int(self.resnet.attnpool.v_proj.in_features)
|
| 143 |
+
self.res_features = nn.Sequential(*list(self.resnet.children())[:-1])
|
| 144 |
+
else:
|
| 145 |
+
num_ftrs = int(self.resnet.fc.in_features)
|
| 146 |
+
self.res_features = nn.Sequential(*list(self.resnet.children())[:-2])
|
| 147 |
+
# here num_ftrs = 2048
|
| 148 |
+
self.res_l1 = nn.Linear(num_ftrs, num_ftrs)
|
| 149 |
+
self.res_l2 = nn.Linear(num_ftrs, 768)
|
| 150 |
+
|
| 151 |
+
def _get_res_basemodel(self, res_model_name):
|
| 152 |
+
try:
|
| 153 |
+
res_model = self.resnet_dict[res_model_name]
|
| 154 |
+
print("Image feature extractor:", res_model_name)
|
| 155 |
+
return res_model
|
| 156 |
+
except:
|
| 157 |
+
raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
|
| 158 |
+
|
| 159 |
+
def forward(self, img):
|
| 160 |
+
#return (batchsize, patch_num, dim)
|
| 161 |
+
batch_size = img.shape[0]
|
| 162 |
+
res_fea = self.res_features(img)
|
| 163 |
+
# return res_fea
|
| 164 |
+
# res_fea = F.adaptive_avg_pool2d(res_fea, (1, 1))
|
| 165 |
+
res_fea = rearrange(res_fea,'b d n1 n2 -> b (n1 n2) d')
|
| 166 |
+
h = rearrange(res_fea,'b n d -> (b n) d')
|
| 167 |
+
x = self.res_l1(h)
|
| 168 |
+
x = F.relu(x)
|
| 169 |
+
x = self.res_l2(x)
|
| 170 |
+
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
|
| 171 |
+
out_pool = torch.mean(out_emb,dim=1)
|
| 172 |
+
return out_emb,out_pool
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class ModelConvNeXt(nn.Module):
|
| 176 |
+
def __init__(self, convnext_base_model):
|
| 177 |
+
super(ModelConvNeXt, self).__init__()
|
| 178 |
+
self.convnext_dict = {"convnext-tiny": timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000),
|
| 179 |
+
"convnext-base": timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000),
|
| 180 |
+
}
|
| 181 |
+
convnext = self._get_convnext_basemodel(convnext_base_model)
|
| 182 |
+
num_ftrs = int(convnext.head.in_features)
|
| 183 |
+
self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
|
| 184 |
+
self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
|
| 185 |
+
self.conv_l2 = nn.Linear(num_ftrs, 768)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _get_convnext_basemodel(self, convnext_model_name):
|
| 189 |
+
try:
|
| 190 |
+
convnext_model = self.convnext_dict[convnext_model_name]
|
| 191 |
+
print("Image feature extractor:", convnext_model_name)
|
| 192 |
+
return convnext_model
|
| 193 |
+
except:
|
| 194 |
+
raise ("Invalid model name. Check the config file and pass one of: convnext-tiny, convnext-small or convnext-base")
|
| 195 |
+
|
| 196 |
+
def forward(self, img):
|
| 197 |
+
#return (batchsize, patch_num, dim)
|
| 198 |
+
batch_size = img.shape[0]
|
| 199 |
+
conv_fea = self.conv_features(img)
|
| 200 |
+
conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
|
| 201 |
+
conv_fea = rearrange(conv_fea,'b d n1 n2 -> b (n1 n2) d')
|
| 202 |
+
h = rearrange(conv_fea,'b n d -> (b n) d')
|
| 203 |
+
x = self.conv_l1(h)
|
| 204 |
+
x = F.relu(x)
|
| 205 |
+
x = self.conv_l2(x)
|
| 206 |
+
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
|
| 207 |
+
out_pool = torch.mean(out_emb,dim=1)
|
| 208 |
+
return out_emb,out_pool
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# class ModelConvNeXt(nn.Module):
|
| 212 |
+
# def __init__(self, convnext_base_model):
|
| 213 |
+
# super(ModelConvNeXt, self).__init__()
|
| 214 |
+
# self.convnext_dict = {"convnext-tiny": models.convnext_tiny(weights='ConvNeXt_Tiny_Weights.DEFAULT'),
|
| 215 |
+
# "convnext-small": models.convnext_small(weights='ConvNeXt_Small_Weights.DEFAULT'),
|
| 216 |
+
# "convnext-base": models.convnext_base(weights='ConvNeXt_Base_Weights.DEFAULT'),
|
| 217 |
+
# }
|
| 218 |
+
# convnext = self._get_convnext_basemodel(convnext_base_model)
|
| 219 |
+
# num_ftrs = int(convnext.classifier[-1].in_features)
|
| 220 |
+
# self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
|
| 221 |
+
# self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
|
| 222 |
+
# self.conv_l2 = nn.Linear(num_ftrs, 768)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# def _get_convnext_basemodel(self, convnext_model_name):
|
| 226 |
+
# try:
|
| 227 |
+
# convnext_model = self.convnext_dict[convnext_model_name]
|
| 228 |
+
# print("Image feature extractor:", convnext_model_name)
|
| 229 |
+
# return convnext_model
|
| 230 |
+
# except:
|
| 231 |
+
# raise ("Invalid model name. Check the config file and pass one of: convnext-tiny, convnext-small or convnext-base")
|
| 232 |
+
|
| 233 |
+
# def forward(self, img):
|
| 234 |
+
# #return (batchsize, patch_num, dim)
|
| 235 |
+
# batch_size = img.shape[0]
|
| 236 |
+
# conv_fea = self.conv_features(img)
|
| 237 |
+
# conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
|
| 238 |
+
# conv_fea = rearrange(conv_fea,'b d n1 n2 -> b (n1 n2) d')
|
| 239 |
+
# h = rearrange(conv_fea,'b n d -> (b n) d')
|
| 240 |
+
# x = self.conv_l1(h)
|
| 241 |
+
# x = F.relu(x)
|
| 242 |
+
# x = self.conv_l2(x)
|
| 243 |
+
# out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
|
| 244 |
+
# out_pool = torch.mean(out_emb,dim=1)
|
| 245 |
+
# return out_emb,out_pool
|
| 246 |
+
|
| 247 |
+
# import open_clip
|
| 248 |
+
# class ModelCLIP(nn.Module):
|
| 249 |
+
# def __init__(self, clip_base_model):
|
| 250 |
+
# super(ModelCLIP, self).__init__()
|
| 251 |
+
# # 根据clip_base_model加载不同的模型
|
| 252 |
+
# if clip_base_model == 'openai_EVA02-B-16':
|
| 253 |
+
# model, _, preprocess = open_clip.create_model_and_transforms('EVA02-B-16', pretrained='merged2b_s8b_b131k')
|
| 254 |
+
|
| 255 |
+
# elif clip_base_model == 'openai_convnext_base_w':
|
| 256 |
+
# model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')
|
| 257 |
+
|
| 258 |
+
# else:
|
| 259 |
+
# raise ("Invalid model name. Check the config file and pass one of: EVA02-B-16 or convnext_base_w")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class ModelEfficientV2(nn.Module):
|
| 264 |
+
def __init__(self, efficientv2_base_model):
|
| 265 |
+
super(ModelEfficientV2, self).__init__()
|
| 266 |
+
self.efficientv2_dict = {"efficientnet_v2_s": models.efficientnet_v2_s(weights='EfficientNet_V2_S_Weights.IMAGENET1K_V1'),}
|
| 267 |
+
self.efficientv2_model = self._get_efficientv2_basemodel(efficientv2_base_model)
|
| 268 |
+
num_ftrs = int(self.efficientv2_model.classifier[-1].in_features)
|
| 269 |
+
self.efficientv2_features = nn.Sequential(*list(self.efficientv2_model.children())[:-2])
|
| 270 |
+
self.efficientv2_l1 = nn.Linear(num_ftrs, num_ftrs)
|
| 271 |
+
self.efficientv2_l2 = nn.Linear(num_ftrs, 768)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _get_efficientv2_basemodel(self, efficientv2_model_name):
|
| 275 |
+
try:
|
| 276 |
+
efficientv2_model = self.efficientv2_dict[efficientv2_model_name]
|
| 277 |
+
print("Image feature extractor:", efficientv2_model_name)
|
| 278 |
+
return efficientv2_model
|
| 279 |
+
except:
|
| 280 |
+
raise ("Invalid model name. Check the config file and pass one of: efficientnetv2_rw_s")
|
| 281 |
+
|
| 282 |
+
def forward(self, img):
|
| 283 |
+
batch_size = img.shape[0]
|
| 284 |
+
efficientv2_fea = self.efficientv2_features(img)
|
| 285 |
+
# efficientv2_fea = F.adaptive_avg_pool2d(efficientv2_fea, (1, 1))
|
| 286 |
+
# print(efficientv2_fea.shape)
|
| 287 |
+
efficientv2_fea = rearrange(efficientv2_fea,'b d n1 n2 -> b (n1 n2) d')
|
| 288 |
+
# print(efficientv2_fea.shape)
|
| 289 |
+
h = rearrange(efficientv2_fea,'b n d -> (b n) d')
|
| 290 |
+
# print(h.shape)
|
| 291 |
+
x = self.efficientv2_l1(h)
|
| 292 |
+
x = F.relu(x)
|
| 293 |
+
x = self.efficientv2_l2(x)
|
| 294 |
+
# print(x.shape)
|
| 295 |
+
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
|
| 296 |
+
out_pool = torch.mean(out_emb,dim=1)
|
| 297 |
+
return out_emb,out_pool
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class ModelDense(nn.Module):
|
| 302 |
+
def __init__(self, dense_base_model):
|
| 303 |
+
super(ModelDense, self).__init__()
|
| 304 |
+
|
| 305 |
+
self.densenet_dict = {"densenet121": models.densenet121(weights='DenseNet121_Weights.IMAGENET1K_V1'),
|
| 306 |
+
"densenet161": models.densenet161(weights='DenseNet161_Weights.IMAGENET1K_V1'),
|
| 307 |
+
"densenet201": models.densenet201(weights='DenseNet201_Weights.IMAGENET1K_V1'),}
|
| 308 |
+
self.densenet = self._get_dense_basemodel(dense_base_model)
|
| 309 |
+
num_ftrs = int(self.densenet.classifier.in_features)
|
| 310 |
+
self.dense_features = self.densenet.features
|
| 311 |
+
self.dense_l1 = nn.Linear(num_ftrs, num_ftrs)
|
| 312 |
+
self.dense_l2 = nn.Linear(num_ftrs, 768)
|
| 313 |
+
|
| 314 |
+
def _get_dense_basemodel(self, dense_base_model):
|
| 315 |
+
try:
|
| 316 |
+
dense_model = self.densenet_dict[dense_base_model]
|
| 317 |
+
print("Image feature extractor:", dense_base_model)
|
| 318 |
+
return dense_model
|
| 319 |
+
except:
|
| 320 |
+
raise ("Invalid model name. Check the config file and pass one of: densenet121 or densenet161")
|
| 321 |
+
|
| 322 |
+
def forward(self, img):
|
| 323 |
+
batch_size = img.shape[0]
|
| 324 |
+
dense_fea = self.dense_features(img)#N, 1024, 7,7
|
| 325 |
+
dense_fea = rearrange(dense_fea,'b d n1 n2 -> b (n1 n2) d')
|
| 326 |
+
h = rearrange(dense_fea,'b n d -> (b n) d')
|
| 327 |
+
x = self.dense_l1(h)
|
| 328 |
+
x = F.relu(x)
|
| 329 |
+
x = self.dense_l2(x)
|
| 330 |
+
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size)
|
| 331 |
+
out_pool = torch.mean(out_emb,dim=1)
|
| 332 |
+
return out_emb,out_pool
|
| 333 |
+
|
| 334 |
+
class TQN_Model(nn.Module):
|
| 335 |
+
def __init__(self,
|
| 336 |
+
embed_dim: int = 768,
|
| 337 |
+
class_num: int = 1,
|
| 338 |
+
lam: list = [1, 0]
|
| 339 |
+
):
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.d_model = embed_dim
|
| 342 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 343 |
+
# decoder_layer = TransformerDecoderLayer(self.d_model, 4, 1024,
|
| 344 |
+
# 0.1, 'relu',normalize_before=True)
|
| 345 |
+
decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
|
| 346 |
+
0.1, 'relu', True, lam)
|
| 347 |
+
self.decoder_norm = nn.LayerNorm(self.d_model)
|
| 348 |
+
# self.decoder = TransformerDecoder(decoder_layer, 4, self.decoder_norm,
|
| 349 |
+
# return_intermediate=False)
|
| 350 |
+
self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
|
| 351 |
+
return_intermediate=False)
|
| 352 |
+
|
| 353 |
+
self.dropout_feas = nn.Dropout(0.1)
|
| 354 |
+
|
| 355 |
+
# class_num = 2
|
| 356 |
+
self.mlp_head = nn.Sequential( # nn.LayerNorm(768),
|
| 357 |
+
nn.Linear(embed_dim, class_num)
|
| 358 |
+
)
|
| 359 |
+
self.apply(self._init_weights)
|
| 360 |
+
|
| 361 |
+
@staticmethod
|
| 362 |
+
def _init_weights(module):
|
| 363 |
+
if isinstance(module, nn.Linear):
|
| 364 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 365 |
+
|
| 366 |
+
elif isinstance(module, nn.MultiheadAttention):
|
| 367 |
+
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
|
| 368 |
+
module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
|
| 369 |
+
|
| 370 |
+
elif isinstance(module, nn.Embedding):
|
| 371 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 372 |
+
if module.padding_idx is not None:
|
| 373 |
+
module.weight.data[module.padding_idx].zero_()
|
| 374 |
+
|
| 375 |
+
# def forward(self, image_features, text_features):
|
| 376 |
+
# #image_features (batch_size,patch_num,dim)
|
| 377 |
+
# #text_features (query_num,dim)
|
| 378 |
+
# batch_size = image_features.shape[0]
|
| 379 |
+
# image_features = image_features.transpose(0,1)
|
| 380 |
+
# text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
|
| 381 |
+
# image_features = self.decoder_norm(image_features)
|
| 382 |
+
# text_features = self.decoder_norm(text_features)
|
| 383 |
+
|
| 384 |
+
# # features = self.decoder(text_features, image_features,
|
| 385 |
+
# # memory_key_padding_mask=None, pos=None, query_pos=None)
|
| 386 |
+
|
| 387 |
+
# image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0)
|
| 388 |
+
# features = self.decoderV1(text_features, image_features, image_features_pool,
|
| 389 |
+
# memory_key_padding_mask=None, pos=None, query_pos=None)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# features = self.dropout_feas(features).transpose(0,1) #b,embed_dim
|
| 394 |
+
# out = self.mlp_head(features) #(batch_size, query_num)
|
| 395 |
+
# # out = out.squeeze(-1)
|
| 396 |
+
# return out
|
| 397 |
+
|
| 398 |
+
def forward(self, image_features, text_features, return_atten = False):
|
| 399 |
+
#image_features (batch_size,patch_num,dim)
|
| 400 |
+
#text_features (query_num,dim)
|
| 401 |
+
batch_size = image_features.shape[0]
|
| 402 |
+
image_features = image_features.transpose(0,1)
|
| 403 |
+
text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
|
| 404 |
+
image_features = self.decoder_norm(image_features)
|
| 405 |
+
text_features = self.decoder_norm(text_features)
|
| 406 |
+
|
| 407 |
+
image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0)
|
| 408 |
+
features,atten_map = self.decoderV1(text_features, image_features, image_features_pool,
|
| 409 |
+
memory_key_padding_mask=None, pos=None, query_pos=None)
|
| 410 |
+
features = self.dropout_feas(features).transpose(0,1) #b,embed_dim
|
| 411 |
+
out = self.mlp_head(features) #(batch_size, query_num)
|
| 412 |
+
if return_atten:
|
| 413 |
+
return out, atten_map
|
| 414 |
+
else:
|
| 415 |
+
return out
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class TQN_Model_Ensemble(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 class_num: int = 1,
                 lam: list = [1, 0]
                 ):
        super().__init__()
        self.d_model = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
                                                    0.1, 'relu', True, lam)
        self.decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder_norm_1 = nn.LayerNorm(self.d_model)
        self.decoder_norm_2 = nn.LayerNorm(self.d_model)
        self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
                                              return_intermediate=False)
        self.decoderV1_1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_1,
                                                return_intermediate=False)
        self.decoderV1_2 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_2,
                                                return_intermediate=False)

        self.dropout_feas = nn.Dropout(0.1)

        # class_num = 2
        self.mlp_head = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_1 = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_2 = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, return_atten=False):

        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
        image_features = self.decoder_norm(image_features)
        # note: image_features / text_features have already been normalized by
        # decoder_norm above, so the _1 / _2 branches normalize twice
        image_features_1 = self.decoder_norm_1(image_features)
        image_features_2 = self.decoder_norm_2(image_features)

        text_features = self.decoder_norm(text_features)
        text_features_1 = self.decoder_norm_1(text_features)
        text_features_2 = self.decoder_norm_2(text_features)

        image_features_pool = torch.mean(image_features, dim=0).unsqueeze(0)
        image_features_pool_1 = torch.mean(image_features_1, dim=0).unsqueeze(0)
        image_features_pool_2 = torch.mean(image_features_2, dim=0).unsqueeze(0)

        features, atten_map = self.decoderV1(text_features, image_features, image_features_pool,
                                             memory_key_padding_mask=None, pos=None, query_pos=None)
        features = self.dropout_feas(features).transpose(0, 1)
        out = self.mlp_head(features)

        features_1, atten_map_1 = self.decoderV1_1(text_features_1, image_features_1, image_features_pool_1,
                                                   memory_key_padding_mask=None, pos=None, query_pos=None)
        features_1 = self.dropout_feas(features_1).transpose(0, 1)
        out_1 = self.mlp_head_1(features_1)

        features_2, atten_map_2 = self.decoderV1_2(text_features_2, image_features_2, image_features_pool_2,
                                                   memory_key_padding_mask=None, pos=None, query_pos=None)
        features_2 = self.dropout_feas(features_2).transpose(0, 1)
        out_2 = self.mlp_head_2(features_2)

        # Average the three heads' logits
        out_stack = torch.stack([out, out_1, out_2])
        out = torch.mean(out_stack, dim=0)

        if return_atten:
            # Note: only the first decoder's attention map is returned
            return out, atten_map
        else:
            return out


# For MIMIC: batch_size=32, query_num=41, patch_num=256, dim=768
# img    256, 32, 768
# txt      1, 32, 768
# query   41, 32, 768
# fts     41, 32, 768
# out     41, 32, 1
# No sigmoid is applied here -- apply the sigmoid when computing the loss!


if __name__ == "__main__":
|
| 521 |
+
|
| 522 |
+
#torch 1.10.2 to torch 1.12.1
|
| 523 |
+
#torchvision-0.11.3 to torchvision-0.13.1
|
| 524 |
+
|
| 525 |
+
# image = torch.randn(1, 3, 224, 224)
|
| 526 |
+
# image_encoder = ModelRes(res_base_model = 'resnet50')
|
| 527 |
+
# # image_encoder = ModelDense(dense_base_model = 'densenet121')
|
| 528 |
+
# # image_encoder = ModelViT(vit_base_model = 'vit_b_32')
|
| 529 |
+
# image_encoder(image)
|
| 530 |
+
|
| 531 |
+
# image = torch.randn(256, 1, 768)
|
| 532 |
+
# query = torch.randn(41, 768)
|
| 533 |
+
# model = TQN_Model()
|
| 534 |
+
# out = model(image, query)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# img = torch.randn(1,3,512,512)
|
| 538 |
+
img = torch.randn(2,3,224,224)
|
| 539 |
+
# model = ModelConvNeXt(convnext_base_model = 'convnext-base')
|
| 540 |
+
# model = ModelEfficientV2(efficientv2_base_model = 'efficientnet_v2_s')
|
| 541 |
+
model = ModelRes(res_base_model = 'resnet50_openai')
|
| 542 |
+
out_emb, out_pool = model(img)
|
| 543 |
+
|
| 544 |
+
print(out_emb.size(), out_pool.size())
|
| 545 |
+
|
| 546 |
+
|
models/resnet.py
ADDED
@@ -0,0 +1,1382 @@
from functools import partial
from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import WeightsEnum, Weights
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
import math
import torch.nn.functional as F

import random

from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t


class LoRALayer(nn.Module):
    """
    Base LoRA class
    """
    def __init__(
        self,
        r,
        lora_alpha,
    ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        # Mark the weight as unmerged
        self.merged = False

    def reset_parameters(self):
        raise NotImplementedError

    def train(self, mode: bool = True):
        raise NotImplementedError

    def eval(self):
        raise NotImplementedError


class LoRALinear(LoRALayer):
    def __init__(self, r, lora_alpha, linear_layer):
        """
        LoRA class for nn.Linear class
        :param r: low rank dimension
        :param lora_alpha: scaling factor
        :param linear_layer: target nn.Linear layer for applying LoRA
        """
        super().__init__(r, lora_alpha)
        self.linear = linear_layer

        in_features = self.linear.in_features
        out_features = self.linear.out_features

        # LoRA configuration
        self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
        self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
        self.scaling = self.lora_alpha / self.r
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        self.linear.train(mode)
        if self.merged:
            # Unmerge the LoRA update from the base weight
            self.linear.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

    def eval(self):
        self.linear.eval()
        if not self.merged:
            # Merge the LoRA update into the base weight
            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x):
        if not self.merged:
            result = F.linear(x, self.linear.weight, bias=self.linear.bias)
            # Scale the low-rank update so the unmerged path matches the
            # merged eval() path (the original omitted * self.scaling here)
            out = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
            result += out
            return result
        else:
            return F.linear(x, self.linear.weight, bias=self.linear.bias)

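# Sanity-check sketch (an editor's addition, not from the original file): with
# the scaling applied in forward(), the unmerged training path and the merged
# eval() path should produce the same output.
#
#   base = nn.Linear(16, 8)
#   lora = LoRALinear(r=4, lora_alpha=8, linear_layer=base)
#   x = torch.randn(2, 16)
#   y_unmerged = lora(x)      # base output + scaled low-rank update
#   lora.eval()               # folds (lora_B @ lora_A) * scaling into weight
#   y_merged = lora(x)
#   assert torch.allclose(y_unmerged, y_merged, atol=1e-6)

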
class LoraConv2d(nn.Conv2d):
    def __init__(
        self,
        r: int,
        lora_alpha: float,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ):
        """
        LoRA class for nn.Conv2d class
        """
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        self.r = r
        self.lora_alpha = lora_alpha
        # LoRA configuration (assumes a square integer kernel_size)
        self.lora_A = nn.Parameter(
            self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
        )
        self.lora_B = nn.Parameter(
            self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
        )
        self.scaling = self.lora_alpha / self.r
        self.reset_parameters_lora()
        self.merged = False
        self.drop_lora_rate = 0.9  # only used by the commented-out LoRA-dropout path in forward()

    def reset_parameters_lora(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        super().train(mode)
        if self.merged:
            # Make sure that the weights are not merged
            self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
            self.merged = False

    def eval(self):
        super().eval()
        if not self.merged:
            # Merge the weights and mark it
            self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
            self.merged = True

    def forward(self, x):
        # Draw a random number
        # drop_rate = random.random()

        # # During training, skip the LoRA branch with a fixed probability
        # if drop_rate <= self.drop_lora_rate and self.training:
        #     return F.conv2d(
        #         x,
        #         self.weight,
        #         self.bias, self.stride, self.padding, self.dilation, self.groups
        #     )
        # else:
        return F.conv2d(
            x,
            self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
            self.bias, self.stride, self.padding, self.dilation, self.groups
        )

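# Usage sketch (an editor's illustration, not from the original file); the
# low-rank factors reshape into the full (out, in, k, k) kernel via .view():
#
#   conv = LoraConv2d(r=4, lora_alpha=8, in_channels=64, out_channels=64,
#                     kernel_size=3, padding=1, bias=False)
#   x = torch.randn(2, 64, 32, 32)
#   y = conv(x)   # uses weight + (lora_B @ lora_A).view(weight.shape) * scaling

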
class MultiLoRALinear(LoRALayer):
    def __init__(self, r, lora_alpha, linear_layer, lora_num):
        """
        Multi-adapter LoRA class for nn.Linear
        :param r: list of low-rank dimensions, one per adapter
        :param lora_alpha: scaling factor
        :param linear_layer: target nn.Linear layer for applying LoRA
        :param lora_num: number of LoRA adapters
        """
        super().__init__(r, lora_alpha)
        self.linear = linear_layer
        self.lora_num = lora_num
        self.r_list = r

        in_features = self.linear.in_features
        out_features = self.linear.out_features

        # LoRA configuration: one (A, B) pair per adapter
        self.lora_A_list = nn.ParameterList([nn.Parameter(self.linear.weight.new_zeros((self.r_list[th], in_features))) for th in range(self.lora_num)])
        self.lora_B_list = nn.ParameterList([nn.Parameter(self.linear.weight.new_zeros((out_features, self.r_list[th]))) for th in range(self.lora_num)])
        # self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
        # self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
        self.scaling = [self.lora_alpha / self.r_list[th] for th in range(self.lora_num)]
        self.reset_parameters()

    def reset_parameters(self):
        for th in range(self.lora_num):
            nn.init.kaiming_uniform_(self.lora_A_list[th], a=math.sqrt(5))
            nn.init.zeros_(self.lora_B_list[th])

    def train(self, mode: bool = True):
        self.linear.train(mode)

    def eval(self):
        self.linear.eval()

    def forward(self, x, weights):
        # `weights` is currently unused: the per-adapter weighted sum below is
        # commented out, so the adapters are summed uniformly instead.
        if not self.merged:
            result = F.linear(x, self.linear.weight, bias=self.linear.bias)  # (247, batch, 768)
            out_stack = torch.stack([(x @ self.lora_A_list[th].T @ self.lora_B_list[th].T) * self.scaling[th] for th in range(self.lora_num)], dim=2)  # e.g. (2353, 16, 3, 768)
            # (247, batch, lora_num, 768)
            # weights = weights.unsqueeze(0).unsqueeze(-1)
            # (1, batch, lora_num, 1)
            # out = torch.sum(out_stack * weights, dim=2)
            out = torch.sum(out_stack, dim=2)
            # (247, batch, 768)
            result += out
            # (247, batch, 768)
            return result
        else:
            return F.linear(x, self.linear.weight, bias=self.linear.bias)

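# Sketch (an editor's illustration; the per-adapter ranks are an assumption
# drawn from the indexing of r_list above):
#
#   base = nn.Linear(768, 768)
#   multi = MultiLoRALinear(r=[4, 8], lora_alpha=8, linear_layer=base, lora_num=2)
#   x = torch.randn(247, 16, 768)     # (seq, batch, dim)
#   y = multi(x, weights=None)        # adapters summed uniformly, see above

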
class MultiLoraConv2d(LoRALayer):
    def __init__(self, r, lora_alpha, conv_layer, num_task):
        """
        Multi-adapter LoRA class for nn.Conv2d
        """
        super().__init__(r, lora_alpha)
        self.conv = conv_layer
        self.num_task = num_task

        in_channels = self.conv.in_channels
        out_channels = self.conv.out_channels
        kernel_size = self.conv.kernel_size[0]

        # LoRA configuration: one (A, B) pair per task
        self.lora_A_list = nn.ParameterList([nn.Parameter(self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))) for th in range(num_task)])
        self.lora_B_list = nn.ParameterList([nn.Parameter(self.conv.weight.new_zeros((out_channels * kernel_size, r * kernel_size))) for th in range(num_task)])

        self.scaling = self.lora_alpha / self.r
        self.reset_parameters()

        self.merged = False
        self.label_batch = None

    def reset_parameters(self):
        for th in range(self.num_task):
            nn.init.kaiming_uniform_(self.lora_A_list[th], a=math.sqrt(5))
            nn.init.zeros_(self.lora_B_list[th])

    def train(self, mode: bool = True):
        self.conv.train(mode)

    def eval(self):
        self.conv.eval()

    def forward(self, input_x, alphas=None):
        if not self.merged:
            conv_weight_stack = torch.cat([(self.lora_B_list[th] @ self.lora_A_list[th]).view(self.conv.weight.shape).unsqueeze(0) * self.scaling for th in range(self.num_task)], dim=0)

            if isinstance(input_x, dict):
                # print('input is dict')
                x, alphas = input_x[0], input_x[1]
            else:
                x = input_x
            batch_size, c = x.shape[0], x.shape[1]
            # print(alphas)
            if alphas is None:
                print('alphas is only None inside lora_fast')
            # Per-sample kernels: mix the task-specific LoRA updates with the
            # per-sample weights `alphas`, then add the shared base weight
            agg_weights = self.conv.weight + torch.sum(
                torch.mul(conv_weight_stack.unsqueeze(0), alphas.view(batch_size, -1, 1, 1, 1, 1)), dim=1)

            # Apply all per-sample kernels in one grouped convolution
            agg_weights = agg_weights.view(-1, *agg_weights.shape[-3:])
            x_grouped = x.view(1, -1, *x.shape[-2:])

            outputs = F.conv2d(x_grouped, agg_weights, self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation, groups=batch_size)
            outputs = outputs.view(batch_size, -1, *outputs.shape[-2:])

            return outputs
        else:
            # fixed: was self.conv(x), but x is undefined on this branch
            return self.conv(input_x)

    def merged_weight(self, th):  # only for test
        self.conv.weight.data += (self.lora_B_list[th] @ self.lora_A_list[th]).view(self.conv.weight.shape) * self.scaling
        self.merged = True

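# Sketch of the per-sample mixing above (an editor's illustration; names and
# shapes are assumptions): each sample gets its own aggregated kernel, and a
# single convolution with groups=batch_size applies all of them in one call.
#
#   conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
#   mconv = MultiLoraConv2d(r=4, lora_alpha=8, conv_layer=conv, num_task=3)
#   x = torch.randn(5, 16, 28, 28)
#   alphas = torch.softmax(torch.randn(5, 3), dim=1)   # per-sample task weights
#   y = mconv(x, alphas)                               # (5, 32, 28, 28)

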
__all__ = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet34_Weights",
    "ResNet50_Weights",
    "ResNet101_Weights",
    "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights",
    "ResNeXt101_32X8D_Weights",
    "ResNeXt101_64X4D_Weights",
    "Wide_ResNet50_2_Weights",
    "Wide_ResNet101_2_Weights",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3_lora(r: int, lora_alpha: float, in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 LoRA convolution with padding"""
    return LoraConv2d(
        r, lora_alpha,
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1_lora(r: int, lora_alpha: float, in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 LoRA convolution"""
    return LoraConv2d(r, lora_alpha, in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock_Lora(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        r: int,
        lora_alpha: float,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3_lora(r, lora_alpha, inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3_lora(r, lora_alpha, planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition": https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck_Lora(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition": https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        r: int,
        lora_alpha: float,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1_lora(r, lora_alpha, inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3_lora(r, lora_alpha, width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1_lora(r, lora_alpha, width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


class ResNet_Lora(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        r: int,
        lora_alpha: float,
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.r = r
        self.lora_alpha = lora_alpha
        self.conv1 = LoraConv2d(self.r, self.lora_alpha, 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, LoraConv2d):
                # note: unreachable, since LoraConv2d subclasses nn.Conv2d and
                # is already handled by the branch above
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1_lora(self.r, self.lora_alpha, self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, self.r, self.lora_alpha, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    self.r,
                    self.lora_alpha,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


def _resnet_lora(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    r: int,
    lora_alpha: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet_Lora:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet_Lora(block, layers, r, lora_alpha, **kwargs)
    if weights is not None:
        missing_keys, unexpected_keys = model.load_state_dict(weights.get_state_dict(progress=progress), strict=False)

        # Only the LoRA factors may legitimately be absent from the checkpoint
        for key_name in missing_keys:
            if 'lora_A' in key_name or 'lora_B' in key_name:
                pass
            else:
                raise ValueError(f'{key_name} in missing keys')

        if unexpected_keys != []:
            raise ValueError(f'Unexpected keys in state_dict: {unexpected_keys}')

    return model

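# Construction sketch (an editor's illustration, not from the original file):
# a LoRA-augmented ResNet-50 initialized from pretrained ImageNet weights;
# ResNet50_Weights is defined further below in this file.
#
#   model = _resnet_lora(Bottleneck_Lora, [3, 4, 6, 3], r=4, lora_alpha=8,
#                        weights=ResNet50_Weights.IMAGENET1K_V1, progress=True)
#   # only the lora_A / lora_B tensors are missing from the checkpoint
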
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class ResNet34_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 21797672,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.314,
                    "acc@5": 91.420,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class ResNet50_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.130,
                    "acc@5": 92.862,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.858,
                    "acc@5": 95.434,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNet101_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.374,
                    "acc@5": 93.546,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 95.780,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNet152_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.312,
                    "acc@5": 94.046,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.284,
                    "acc@5": 96.002,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt50_32X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.618,
                    "acc@5": 93.698,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.198,
                    "acc@5": 95.340,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt101_32X8D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.312,
                    "acc@5": 94.526,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.834,
                    "acc@5": 96.228,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt101_64X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.246,
                    "acc@5": 96.454,
                }
            },
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class Wide_ResNet50_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
"ImageNet-1K": {
|
| 1187 |
+
"acc@1": 78.468,
|
| 1188 |
+
"acc@5": 94.086,
|
| 1189 |
+
}
|
| 1190 |
+
},
|
| 1191 |
+
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
|
| 1192 |
+
},
|
| 1193 |
+
)
|
| 1194 |
+
IMAGENET1K_V2 = Weights(
|
| 1195 |
+
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
|
| 1196 |
+
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
|
| 1197 |
+
meta={
|
| 1198 |
+
**_COMMON_META,
|
| 1199 |
+
"num_params": 68883240,
|
| 1200 |
+
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
|
| 1201 |
+
"_metrics": {
|
| 1202 |
+
"ImageNet-1K": {
|
| 1203 |
+
"acc@1": 81.602,
|
| 1204 |
+
"acc@5": 95.758,
|
| 1205 |
+
}
|
| 1206 |
+
},
|
| 1207 |
+
"_docs": """
|
| 1208 |
+
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
|
| 1209 |
+
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
|
| 1210 |
+
""",
|
| 1211 |
+
},
|
| 1212 |
+
)
|
| 1213 |
+
DEFAULT = IMAGENET1K_V2
|
| 1214 |
+
|
| 1215 |
+
|
| 1216 |
+
class Wide_ResNet101_2_Weights(WeightsEnum):
|
| 1217 |
+
IMAGENET1K_V1 = Weights(
|
| 1218 |
+
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
|
| 1219 |
+
transforms=partial(ImageClassification, crop_size=224),
|
| 1220 |
+
meta={
|
| 1221 |
+
**_COMMON_META,
|
| 1222 |
+
"num_params": 126886696,
|
| 1223 |
+
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
|
| 1224 |
+
"_metrics": {
|
| 1225 |
+
"ImageNet-1K": {
|
| 1226 |
+
"acc@1": 78.848,
|
| 1227 |
+
"acc@5": 94.284,
|
| 1228 |
+
}
|
| 1229 |
+
},
|
| 1230 |
+
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
|
| 1231 |
+
},
|
| 1232 |
+
)
|
| 1233 |
+
IMAGENET1K_V2 = Weights(
|
| 1234 |
+
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
|
| 1235 |
+
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
|
| 1236 |
+
meta={
|
| 1237 |
+
**_COMMON_META,
|
| 1238 |
+
"num_params": 126886696,
|
| 1239 |
+
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
|
| 1240 |
+
"_metrics": {
|
| 1241 |
+
"ImageNet-1K": {
|
| 1242 |
+
"acc@1": 82.510,
|
| 1243 |
+
"acc@5": 96.020,
|
| 1244 |
+
}
|
| 1245 |
+
},
|
| 1246 |
+
"_docs": """
|
| 1247 |
+
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
|
| 1248 |
+
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
|
| 1249 |
+
""",
|
| 1250 |
+
},
|
| 1251 |
+
)
|
| 1252 |
+
DEFAULT = IMAGENET1K_V2
|
| 1253 |
+
|
| 1254 |
+
|
| 1255 |
+
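# A minimal consumption sketch for the WeightsEnum entries above (hedged,
# standard torchvision usage; `img` is an illustrative PIL image, not a name
# from this file). Each entry bundles a checkpoint URL with the evaluation
# transforms its accuracy numbers were measured with:
#
#     weights = ResNet152_Weights.IMAGENET1K_V2
#     model = resnet152(weights=weights).eval()
#     preprocess = weights.transforms()   # resize 232 -> center-crop 224 -> normalize
#     logits = model(preprocess(img).unsqueeze(0))
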
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet18_Weights.verify(weights)

    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet34_Weights.verify(weights)

    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet50_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet50_Weights.verify(weights)

    return _resnet_lora(Bottleneck_Lora, [3, 4, 6, 3], r, lora_alpha, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet101_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet101_Weights.verify(weights)

    return _resnet_lora(Bottleneck_Lora, [3, 4, 23, 3], r, lora_alpha, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet152_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152_lora(*, r: int, lora_alpha: float, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    weights = ResNet152_Weights.verify(weights)

    return _resnet_lora(Bottleneck_Lora, [3, 8, 36, 3], r, lora_alpha, weights, progress, **kwargs)

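# A hedged fine-tuning sketch for the *_lora builders above, which swap in
# Bottleneck_Lora (defined earlier in this file) via _resnet_lora. It assumes
# the adapter weights carry "lora_" in their parameter names, the loralib
# convention; that naming is not verified here:
#
#     model = resnet50_lora(r=16, lora_alpha=16,
#                           weights=ResNet50_Weights.IMAGENET1K_V2)
#     for name, param in model.named_parameters():
#         param.requires_grad = "lora_" in name   # train only the low-rank factors
#     optimizer = torch.optim.AdamW(
#         [p for p in model.parameters() if p.requires_grad], lr=1e-4)
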
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
    *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = ResNeXt50_32X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
    *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = ResNeXt101_32X8D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


def resnext101_64x4d(
    *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = ResNeXt101_64X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = Wide_ResNet50_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
    *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    weights = Wide_ResNet101_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


# The dictionary below is internal implementation detail and will be removed in v0.15
from torchvision.models._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "resnet18": ResNet18_Weights.IMAGENET1K_V1.url,
        "resnet34": ResNet34_Weights.IMAGENET1K_V1.url,
        "resnet50": ResNet50_Weights.IMAGENET1K_V1.url,
        "resnet101": ResNet101_Weights.IMAGENET1K_V1.url,
        "resnet152": ResNet152_Weights.IMAGENET1K_V1.url,
        "resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url,
        "resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url,
        "wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url,
        "wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url,
    }
)


if __name__ == '__main__':
    # WeightsEnum.verify also accepts the string form and resolves it to the
    # enum member, so this is equivalent to passing ResNet50_Weights.IMAGENET1K_V2.
    model = resnet50_lora(r=16, lora_alpha=16, weights='ResNet50_Weights.IMAGENET1K_V2')

models/transformer_decoder.py
ADDED
@@ -0,0 +1,320 @@
import copy
from typing import Optional, List
import pickle as cp

import torch
import torch.nn.functional as F
from torch import nn, Tensor
# from visualizer import get_local


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt
        T, B, C = memory.shape  # sequence-first layout: (seq_len, batch, d_model)
        intermediate = []
        for n, layer in enumerate(self.layers):
            residual = True
            output, ws = layer(output, memory, tgt_mask=tgt_mask,
                               memory_mask=memory_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask,
                               pos=pos, query_pos=query_pos, residual=residual)
            if self.return_intermediate:
                intermediate.append(self.norm(output))
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)
        return output

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=1024, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     residual=True):
        # tgt: text features, e.g. torch.Size([14, 1, 768])
        # memory: image features, e.g. torch.Size([49, 1, 768])
        # `residual` is accepted for interface compatibility but never read.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2, ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)
        # NOTE: unlike the standard DETR layer, the self-attention output tgt2 is
        # discarded here; there is no `tgt = tgt + self.dropout1(tgt2)` before norm1.
        tgt = self.norm1(tgt)
        tgt2, ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       need_weights=True,
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)

        # cross-attention weights ws: [B, NUM_Q, T]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, ws

    # @get_local('attn_weights')
    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2, ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        # attn_weights are averaged over heads by nn.MultiheadAttention,
        # so their shape is (N, L, S), not (N, num_heads, L, S)
        tgt2, attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                 key=self.with_pos_embed(memory, pos),
                                                 value=memory, attn_mask=memory_mask,
                                                 key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt, attn_weights

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                residual=True):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, residual)

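# A minimal usage sketch for the two classes above (hedged; all sizes are
# illustrative and follow the shape comments in forward_post):
#
#     layer = TransformerDecoderLayer(d_model=768, nhead=8, normalize_before=True)
#     decoder = TransformerDecoder(layer, num_layers=4, norm=nn.LayerNorm(768))
#     tgt = torch.zeros(14, 1, 768)     # label/text queries (L, B, C)
#     memory = torch.randn(49, 1, 768)  # image patch tokens (S, B, C)
#     out = decoder(tgt, memory)        # -> (14, 1, 768)
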
class TransformerDecoderV1(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                memory_global,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt
        T, B, C = memory.shape
        intermediate = []
        for n, layer in enumerate(self.layers):
            residual = True
            output, ws = layer(output, memory,
                               memory_global,
                               tgt_mask=tgt_mask,
                               memory_mask=memory_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask,
                               pos=pos, query_pos=query_pos, residual=residual)

            if self.return_intermediate:
                intermediate.append(self.norm(output))
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        # unlike TransformerDecoder, this variant also returns the last
        # layer's (blended) cross-attention weights
        return output, ws


class TransformerDecoderLayerV1(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=1024, dropout=0.1,
                 activation="relu", normalize_before=False, lam=[1, 0]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        # fixed mixing coefficients: lam[0] weights the local (patch-level)
        # cross-attention branch, lam[1] the global branch
        self.lam_l = lam[0]
        self.lam_g = lam[1]

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_global,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     residual=True):
        # tgt: text features, e.g. torch.Size([14, 1, 768])
        # memory: image features, e.g. torch.Size([49, 1, 768])
        # memory_global is unused in the post-norm path.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2, ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)
        # As in TransformerDecoderLayer.forward_post, the self-attention output
        # is discarded before norm1.
        tgt = self.norm1(tgt)
        tgt2, ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       need_weights=True,
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)

        # cross-attention weights ws: [B, NUM_Q, T]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, ws

    # @get_local('attn_weights')
    def forward_pre(self, tgt, memory,
                    memory_global,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2, ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        if memory.shape[0] == 1:
            # a single memory token: only the fine-grained branch applies
            tgt2, attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                     key=self.with_pos_embed(memory, pos),
                                                     value=memory, attn_mask=memory_mask,
                                                     key_padding_mask=memory_key_padding_mask)
        else:
            tgt2_fine, attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                          key=self.with_pos_embed(memory, pos),
                                                          value=memory, attn_mask=memory_mask,
                                                          key_padding_mask=memory_key_padding_mask)
            tgt2_global, attn_weights_global = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                                   key=self.with_pos_embed(memory_global, pos),
                                                                   value=memory_global, attn_mask=memory_mask,
                                                                   key_padding_mask=memory_key_padding_mask)
            # blend the local and global branches with the fixed lam weights
            tgt2 = tgt2_fine * self.lam_l + tgt2_global * self.lam_g
            attn_weights = attn_weights * self.lam_l + attn_weights_global * self.lam_g

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt, attn_weights

    def forward(self, tgt, memory,
                memory_global,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                residual=True):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_global, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_global, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, residual)

def _get_clones(module, N):
    # N independent deep copies, so the stacked layers do not share weights
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
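# A hedged end-to-end sketch of the global + local fusion implemented by the V1
# classes above. The sizes and the lam split are illustrative choices, not
# values taken from this repository.
if __name__ == '__main__':
    layer = TransformerDecoderLayerV1(d_model=768, nhead=8,
                                      normalize_before=True, lam=[0.9, 0.1])
    decoder = TransformerDecoderV1(layer, num_layers=4, norm=nn.LayerNorm(768))
    tgt = torch.zeros(14, 1, 768)     # 14 label queries (L, B, C)
    memory = torch.randn(49, 1, 768)  # 49 patch tokens (S, B, C)
    # Broadcast a pooled global feature to memory's length so the fine and
    # global attention maps can be blended elementwise in forward_pre.
    memory_global = memory.mean(dim=0, keepdim=True).expand_as(memory)
    out, attn = decoder(tgt, memory, memory_global)
    print(out.shape, attn.shape)      # torch.Size([14, 1, 768]) torch.Size([1, 14, 49])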