Continual-Mega commited on
Commit
370c0d0
·
verified ·
1 Parent(s): d959f89

Upload CLIP/adapter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/adapter.py +76 -0
CLIP/adapter.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ import torch.nn.functional as F
7
+
8
+ # Residual CLIP Adapter
9
+ class ClipAdapter(nn.Module):
10
+ def __init__(self, c_in, bottleneck=768):
11
+ super(ClipAdapter, self).__init__()
12
+ self.fc1 = nn.Sequential(
13
+ nn.Linear(c_in, bottleneck, bias=False),
14
+ nn.LeakyReLU(inplace=False)
15
+ )
16
+ self.fc2 = nn.Sequential(
17
+ nn.Linear(bottleneck, c_in, bias=False),
18
+ nn.LeakyReLU(inplace=False)
19
+ )
20
+
21
+ def forward(self, x):
22
+ x = self.fc1(x)
23
+ y = self.fc2(x)
24
+ return x, y
25
+
26
+
27
+ class CLIPAD(nn.Module,
28
+ PyTorchModelHubMixin,
29
+ repo_url="https://github.com/Continual-Mega/Continual-Mega",
30
+ paper_url="https://arxiv.org/abs/2506.00956"):
31
+ def __init__(self, clip_model, features):
32
+ super().__init__()
33
+ self.clipmodel = clip_model
34
+ self.image_encoder = clip_model.visual
35
+ self.features = features
36
+ self.adapters = nn.ModuleList( [ClipAdapter(1024, bottleneck=768) for i in range(len(features))] )
37
+
38
+ def forward(self, x):
39
+ x = self.image_encoder.conv1(x)
40
+ x = x.reshape(x.shape[0], x.shape[1], -1)
41
+ x = x.permute(0, 2, 1)
42
+
43
+ x = torch.cat(
44
+ [self.image_encoder.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
45
+ x], dim=1)
46
+ x = x + self.image_encoder.positional_embedding.to(x.dtype)
47
+
48
+ x = self.image_encoder.patch_dropout(x)
49
+ x = self.image_encoder.ln_pre(x)
50
+
51
+ x = x.permute(1, 0, 2)
52
+
53
+ ada_patch_tokens = []
54
+
55
+ for i, res in enumerate(self.image_encoder.transformer.resblocks):
56
+ x, _ = res(x, attn_mask=None)
57
+ if (i + 1) in self.features:
58
+ adapt_med, adapt_out = self.adapters[self.features.index(i+1)](x)
59
+
60
+ x = 0.9 * x + 0.1 * adapt_out
61
+ ada_patch_tokens.append(adapt_med)
62
+
63
+ x = x.permute(1, 0, 2)
64
+
65
+ ada_patch_tokens = [ada_patch_tokens[t].permute(1, 0, 2) for t in range(len(ada_patch_tokens))]
66
+
67
+ pooled, tokens = self.image_encoder._global_pool(x)
68
+ pooled = self.image_encoder.ln_post(pooled)
69
+
70
+ if self.image_encoder.proj is not None:
71
+ pooled = pooled @ self.image_encoder.proj
72
+
73
+ return pooled, ada_patch_tokens
74
+
75
+
76
+