File size: 7,723 Bytes
940092e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import os.path as osp
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from huggingface_hub import PyTorchModelHubMixin

from CLIP.tokenizer import SimpleTokenizer, tokenize
class TextEncoder(nn.Module):
    """Encode already-embedded prompts with the text branch of a CLIP model.

    The transformer, positional embedding, final LayerNorm and text projection
    are taken by reference from ``clip_model`` (nothing is copied), so any
    change to them is shared with the original model.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: prompt token embeddings laid out batch-first (N, L, D).
            tokenized_prompts: token ids of the same prompts, (N, L). The
                column holding the largest id in each row is the one read out
                (presumably CLIP's end-of-text token, which has the largest id).

        Returns:
            Tensor (N, P), where P is the output width of ``text_projection``.
        """
        hidden = prompts + self.positional_embedding
        # The transformer works sequence-first (L, N, D) and returns a 3-tuple;
        # only the first element (the hidden states) is used here.
        hidden, _, _ = self.transformer(hidden.transpose(0, 1))
        hidden = self.ln_final(hidden.transpose(0, 1))
        rows = torch.arange(hidden.shape[0])
        readout_cols = tokenized_prompts.argmax(dim=-1)
        pooled = hidden[rows, readout_cols]
        return pooled @ self.text_projection
class PromptLearner(nn.Module):
    """Learn CoOp-style context vectors and splice them into fixed text prompts.

    One learnable context tensor is created for every (key of ``prompts``,
    entry of ``class_token_position``) pair. ``forward`` inserts those vectors
    into the frozen token embeddings of the key's prompts at the requested
    position and returns the assembled embeddings per key.
    """

    def __init__(self,
                 prompts,
                 n_ctx,  # prompt max len
                 CSC,  # True or False multi prompt
                 class_token_position,  # cls position
                 clip_model):
        """Create the learnable context and cache the frozen prompt embeddings.

        Args:
            prompts: mapping ``key -> list of prompt strings``. Underscores are
                replaced by spaces before tokenizing. Keys are used in
                parameter/buffer names, so they must not contain ``"."``.
            n_ctx: number of learnable context tokens per prompt.
            CSC: if true, every prompt gets its own context of shape
                ``(n_prompts, n_ctx, dim)``; otherwise a single
                ``(n_ctx, dim)`` context is shared by all prompts of a
                key/position pair.
            class_token_position: iterable of ``"end"``, ``"middle"`` or
                ``"front"``; one context tensor is learned per entry.
            clip_model: CLIP model exposing ``ln_final``, ``token_embedding``
                and ``device``.
        """
        super().__init__()
        # Context vectors are concatenated with token embeddings along the
        # sequence axis, so their width must match the model width.
        ctx_dim = clip_model.ln_final.weight.shape[0]

        ctx = {}
        for cls in prompts:
            for position in class_token_position:
                if CSC:
                    shape = (len(prompts[cls]), n_ctx, ctx_dim)
                else:
                    shape = (n_ctx, ctx_dim)
                ctx_vectors = torch.empty(*shape).to(clip_model.device)
                nn.init.normal_(ctx_vectors, std=0.02)
                ctx['{}_{}'.format(cls, position)] = nn.Parameter(ctx_vectors, requires_grad=True)
        self.ctx = nn.ParameterDict(ctx)  # to be optimized

        # n_ctx "X" placeholder words reserve the token slots that forward()
        # later overwrites with the learned context.
        prompt_prefix = " ".join(["X"] * n_ctx)
        _tokenizer = SimpleTokenizer()
        prompts_split = {cls: [prompt.replace("_", " ") for prompt in prompts[cls]] for cls in prompts}
        # Token count of each bare prompt, used to locate it inside the suffix.
        prompts_lens = {cls: [len(_tokenizer.encode(prompt)) for prompt in texts]
                        for cls, texts in prompts_split.items()}
        prompts_learnable_tokens = {cls: [prompt_prefix + " " + prompt + "." for prompt in texts]
                                    for cls, texts in prompts_split.items()}
        tokenized_prompts = {cls: torch.cat([tokenize(text) for text in texts]).to(clip_model.device)
                             for cls, texts in prompts_learnable_tokens.items()}
        with torch.no_grad():
            embeddings = {cls: clip_model.token_embedding(tokens)
                          for cls, tokens in tokenized_prompts.items()}

        # Frozen embeddings around the placeholders: the first token (presumably
        # start-of-text) and everything after the n_ctx placeholders. They are
        # registered as non-persistent buffers so that .to()/.cuda()/.half()
        # move them together with ctx, while state_dict() (and therefore
        # existing checkpoints) is unchanged.
        for cls, emb in embeddings.items():
            self.register_buffer('{}_token_prefix'.format(cls), emb[:, :1, :], persistent=False)
            self.register_buffer('{}_token_suffix'.format(cls), emb[:, 1 + n_ctx:, :], persistent=False)

        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.prompts_lens = prompts_lens
        self.class_token_position = class_token_position

    @property
    def register_embeddings(self):
        """Return the frozen prefix/suffix embeddings as a dict.

        Keys are ``'<key>_token_prefix'`` and ``'<key>_token_suffix'``. This
        read-only view is kept for code that accessed the former plain-dict
        attribute.
        """
        names = ['{}_token_{}'.format(cls, part)
                 for cls in self.tokenized_prompts
                 for part in ('prefix', 'suffix')]
        return {name: getattr(self, name) for name in names}

    def forward(self):
        """Assemble the prompt embeddings for every key.

        Returns:
            dict mapping each key to a tensor of shape
            ``(len(class_token_position) * n_prompts, seq_len, dim)``. Rows are
            grouped by position, in ``class_token_position`` order.

        Raises:
            ValueError: if a position is not "end", "middle" or "front".
        """
        cls_prompts = {}
        for cls in self.tokenized_prompts:
            prefix = getattr(self, '{}_token_prefix'.format(cls))
            suffix = getattr(self, '{}_token_suffix'.format(cls))
            lens = self.prompts_lens[cls]
            per_position = []
            for position in self.class_token_position:
                ctx = self.ctx['{}_{}'.format(cls, position)]
                if ctx.dim() == 2:
                    # Shared (non-CSC) context: broadcast it to every prompt of this key.
                    ctx = ctx.unsqueeze(0).expand(len(lens), -1, -1)
                per_position.append(self._splice_ctx(position, prefix, ctx, suffix, lens))
            cls_prompts[cls] = torch.cat(per_position, dim=0)
        return cls_prompts

    def _splice_ctx(self, position, prefix, ctx, suffix, lens):
        """Insert ``ctx`` into one key's prompt embeddings at ``position``."""
        if position == "end":
            # [prefix][ctx][prompt tokens, ".", remaining tokens]
            return torch.cat([prefix, ctx, suffix], dim=1)
        if position not in ("middle", "front"):
            raise ValueError("unknown class_token_position: {!r}".format(position))
        half_n_ctx = self.n_ctx // 2
        rows = []
        # Prompt lengths differ per row, so each row is spliced separately.
        for i, p_len in enumerate(lens):
            prefix_i = prefix[i:i + 1]
            class_i = suffix[i:i + 1, :p_len]  # the prompt's own tokens
            suffix_i = suffix[i:i + 1, p_len:]  # "." and the rest of the sequence
            ctx_i = ctx[i:i + 1]
            if position == "middle":
                parts = [prefix_i, ctx_i[:, :half_n_ctx], class_i, ctx_i[:, half_n_ctx:], suffix_i]
            else:  # "front"
                parts = [prefix_i, class_i, ctx_i, suffix_i]
            rows.append(torch.cat(parts, dim=1))
        return torch.cat(rows, dim=0)
class PromptMaker(nn.Module,
                  PyTorchModelHubMixin,
                  repo_url="https://github.com/Continual-Mega/Continual-Mega",
                  paper_url="https://arxiv.org/abs/2506.00956"):
    """Produce one unit-norm text embedding per prompt class from learnable prompts.

    Combines a ``PromptLearner`` (learned context plus frozen prompt
    embeddings) with a ``TextEncoder`` built from the same CLIP model.
    """

    def __init__(self,
                 prompts,
                 clip_model,
                 n_ctx: int = 8,  # prompt max len
                 CSC: bool = True,  # True or False multi prompt
                 class_token_position: Optional[list] = None,  # cls position(s); None -> ['end']
                 ):
        """
        Args:
            prompts: mapping class key -> list of prompt strings; must contain
                the keys ``'normal'`` and ``'abnormal'``.
            clip_model: CLIP model whose text branch encodes the prompts.
            n_ctx: number of learnable context tokens per prompt.
            CSC: learn a separate context per prompt (True) or one shared
                context per key/position (False).
            class_token_position: entries from ``'end'``, ``'middle'`` and
                ``'front'``; defaults to ``['end']``.

        Raises:
            ValueError: if a required key is missing from ``prompts`` or a
                position is not recognised.
        """
        super().__init__()
        # Fresh list per instance: avoids the shared mutable-default pitfall
        # and decouples this module from the caller's list.
        if class_token_position is None:
            class_token_position = ['end']
        else:
            class_token_position = list(class_token_position)
        missing = [key for key in ('normal', 'abnormal') if key not in prompts]
        if missing:
            raise ValueError("prompts is missing required key(s): {}".format(missing))
        for position in class_token_position:
            if position not in ('end', 'middle', 'front'):
                raise ValueError(
                    "unknown class_token_position {!r}; expected 'end', 'middle' or 'front'".format(position))
        self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.class_token_position = class_token_position
        self.text_encoder = TextEncoder(clip_model)

    def forward(self):
        """Encode every class's prompts.

        Returns:
            Tensor of shape ``(d, n_keys)``, where ``d`` is the text-projection
            width. Column ``j`` is the L2-normalised mean embedding of all
            prompts (across all positions) of the j-th key, in ``prompts``
            order.
        """
        prompts = self.prompt_learner()
        n_positions = len(self.class_token_position)
        text_features = []
        for cls, cls_prompts in prompts.items():
            # prompt_learner stacks one block of rows per position, so the
            # token ids are tiled the same way to keep rows aligned.
            tokens = self.tokenized_prompts[cls].repeat(n_positions, 1)
            class_embedding = self.text_encoder(cls_prompts, tokens).mean(dim=0)
            text_features.append(class_embedding / class_embedding.norm())
        return torch.stack(text_features, dim=1)