import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from CLIP.tokenizer import SimpleTokenizer, tokenize

from huggingface_hub import PyTorchModelHubMixin

class TextEncoder(nn.Module):
    """Encodes already-embedded prompt tokens with the (frozen) CLIP text transformer."""

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        # prompts: token embeddings of shape (n_prompts, seq_len, ctx_dim)
        x = prompts + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, _, _ = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # take the features at the EOT token (highest token id in each sequence) and project them
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

class PromptLearner(nn.Module):
    """Learns CoOp-style context vectors for every (class, token-position) combination."""

    def __init__(self,
                 prompts,
                 n_ctx,
                 CSC,
                 class_token_position,
                 clip_model):
        super().__init__()

        ctx_dim = clip_model.ln_final.weight.shape[0]

        # one set of learnable context vectors per (class, position) pair
        self.ctx = {}
        for cls in prompts:
            for position in class_token_position:
                if CSC:
                    # class-specific context: separate vectors for each prompt of this class
                    ctx_vectors = torch.empty(len(prompts[cls]), n_ctx, ctx_dim).to(clip_model.device)
                else:
                    # unified context shared by all prompts of this class
                    ctx_vectors = torch.empty(n_ctx, ctx_dim).to(clip_model.device)
                nn.init.normal_(ctx_vectors, std=0.02)
                self.ctx['{}_{}'.format(cls, position)] = nn.Parameter(ctx_vectors, requires_grad=True)

        self.ctx = nn.ParameterDict(self.ctx)

        # placeholder tokens ("X X ... X") that will later be replaced by the learned context
        prompt_prefix = " ".join(["X"] * n_ctx)

        _tokenizer = SimpleTokenizer()

        prompts_split = {cls: [prompt.replace("_", " ") for prompt in prompts[cls]] for cls in prompts}
        prompts_lens = {cls: [len(_tokenizer.encode(prompt)) for prompt in prompts_split[cls]] for cls in prompts_split}
        prompts_learnable_tokens = {cls: [prompt_prefix + " " + prompt + "." for prompt in prompts_split[cls]] for cls in prompts_split}
        tokenized_prompts = {cls: torch.cat([tokenize(prompt) for prompt in prompts_learnable_tokens[cls]]).to(clip_model.device) for cls in prompts_learnable_tokens}

        with torch.no_grad():
            embeddings = {cls: clip_model.token_embedding(tokenized_prompts[cls]) for cls in tokenized_prompts}

        # frozen parts of each prompt: SOS embedding (prefix) and class tokens + "." + EOS (suffix)
        self.register_embeddings = {}
        for cls in embeddings:
            self.register_embeddings['{}_token_prefix'.format(cls)] = embeddings[cls][:, :1, :]
            self.register_embeddings['{}_token_suffix'.format(cls)] = embeddings[cls][:, 1 + n_ctx:, :]

        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.prompts_lens = prompts_lens
        self.class_token_position = class_token_position

    def forward(self):
        cls_prompts = {}

        for cls in self.tokenized_prompts:
            prefix = self.register_embeddings['{}_token_prefix'.format(cls)]
            suffix = self.register_embeddings['{}_token_suffix'.format(cls)]

            cls_prompts[cls] = []

            for position in self.class_token_position:
                ctx = self.ctx['{}_{}'.format(cls, position)]
                if ctx.dim() == 2:
                    # unified context: broadcast the same vectors to every prompt of this class
                    ctx = ctx.unsqueeze(0).expand(len(self.prompts_lens[cls]), -1, -1)

                if position == "end":
                    # [SOS] [learned ctx] [class tokens, ".", EOS, padding]
                    prompts = torch.cat([prefix, ctx, suffix], dim=1)

                elif position == "middle":
                    # [SOS] [first half of ctx] [class tokens] [second half of ctx] [rest]
                    half_n_ctx = self.n_ctx // 2
                    prompts = []
                    for i in range(len(self.prompts_lens[cls])):
                        p_len = self.prompts_lens[cls][i]
                        prefix_i = prefix[i: i + 1, :, :]
                        class_i = suffix[i: i + 1, :p_len, :]
                        suffix_i = suffix[i: i + 1, p_len:, :]
                        ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
                        ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
                        prompt = torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=1)
                        prompts.append(prompt)
                    prompts = torch.cat(prompts, dim=0)

                else:
                    # [SOS] [class tokens] [learned ctx] [rest]
                    assert position == "front"
                    prompts = []
                    for i in range(len(self.prompts_lens[cls])):
                        p_len = self.prompts_lens[cls][i]
                        prefix_i = prefix[i: i + 1, :, :]
                        class_i = suffix[i: i + 1, :p_len, :]
                        suffix_i = suffix[i: i + 1, p_len:, :]
                        ctx_i = ctx[i: i + 1, :, :]
                        prompt = torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=1)
                        prompts.append(prompt)
                    prompts = torch.cat(prompts, dim=0)

                cls_prompts[cls].append(prompts)

            # concatenate the prompt variants produced for every token position of this class
            cls_prompts[cls] = torch.cat(cls_prompts[cls], dim=0)

        return cls_prompts

class PromptMaker(nn.Module,
                  PyTorchModelHubMixin,
                  repo_url="https://github.com/Continual-Mega/Continual-Mega",
                  paper_url="https://arxiv.org/abs/2506.00956"):
    """Couples the prompt learner with the CLIP text encoder to produce normal/abnormal text features."""

    def __init__(self,
                 prompts,
                 clip_model,
                 n_ctx: int = 8,
                 CSC: bool = True,
                 class_token_position: list = ['end'],
                 ):
        super().__init__()
        assert 'normal' in prompts and 'abnormal' in prompts

        for position in class_token_position:
            assert position in ['end', 'middle', 'front']

        self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        self.class_token_position = class_token_position
        self.text_encoder = TextEncoder(clip_model)

    def forward(self):
        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = []

        for cls in prompts:
            # encode every prompt variant of this class (one copy of the tokenized prompts per
            # token position), then average and L2-normalize to obtain a single class embedding
            class_embedding = self.text_encoder(prompts[cls], tokenized_prompts[cls].repeat(len(self.class_token_position), 1))
            class_embedding = class_embedding.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            text_features.append(class_embedding)

        text_features = torch.stack(text_features, dim=1)

        return text_features
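

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal example of how PromptMaker might
# be driven, assuming a CLIP backbone loaded elsewhere in this repository that
# exposes `transformer`, `token_embedding`, `positional_embedding`, `ln_final`,
# `text_projection` and a `device` attribute, as used above. `load_clip_model`
# is a hypothetical stand-in for that loader and is not defined in this file,
# so the example is kept commented out.
#
#   clip_model = load_clip_model()                     # hypothetical loader
#   prompts = {
#       "normal": ["a photo of a flawless object"],
#       "abnormal": ["a photo of a damaged object"],
#   }
#   prompt_maker = PromptMaker(prompts, clip_model, n_ctx=8, CSC=True,
#                              class_token_position=["end"])
#   text_features = prompt_maker()                     # shape (embed_dim, 2): normal / abnormal columns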