# ADCT / CoOp.py
# Uploaded by Continual-Mega: "Upload CoOp.py with huggingface_hub" (commit 940092e, verified)
import os.path as osp
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from CLIP.tokenizer import SimpleTokenizer,tokenize
from huggingface_hub import PyTorchModelHubMixin
class TextEncoder(nn.Module):
    """Encode already-embedded prompts with the text-side parts of a CLIP model.

    The transformer, positional embedding, final layer norm and text
    projection are taken by reference from ``clip_model``; nothing is copied.
    """

    def __init__(self, clip_model):
        """Borrow the text-tower components from ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: Prompt embeddings laid out as NLD (batch first).
            tokenized_prompts: Token ids for the same prompts; only used to
                locate, per row, the position holding the largest token id
                (presumably CLIP's end-of-text token -- confirm against the
                tokenizer).

        Returns:
            The layer-normed transformer output at that position for every
            prompt, multiplied by ``text_projection``.
        """
        embedded = prompts + self.positional_embedding
        seq_first = embedded.permute(1, 0, 2)  # NLD -> LND
        # The transformer returns three values; only the first one is used.
        seq_first, _, _ = self.transformer(seq_first)
        hidden = self.ln_final(seq_first.permute(1, 0, 2))  # LND -> NLD

        row_index = torch.arange(hidden.shape[0])
        token_index = tokenized_prompts.argmax(dim=-1)
        return hidden[row_index, token_index] @ self.text_projection
class PromptLearner(nn.Module):
    """Learnable context vectors spliced into fixed prompt token embeddings.

    For every class key in ``prompts`` and every entry of
    ``class_token_position`` one trainable context tensor is created. The
    prompt texts are tokenized with ``n_ctx`` placeholder ``"X"`` tokens in
    front; ``forward`` then replaces those placeholders with the learned
    context, placed at the end, middle or front relative to the prompt text.

    Args:
        prompts: Mapping from class key to a list of prompt strings
            (underscores are replaced by spaces).
        n_ctx: Number of learnable context tokens per prompt.
        CSC: If true, every prompt of a class gets its own context tensor of
            shape ``(len(prompts[cls]), n_ctx, dim)``; otherwise one
            ``(n_ctx, dim)`` tensor is shared by all prompts of the class.
        class_token_position: Iterable of layouts, each one of ``"end"``,
            ``"middle"`` or ``"front"``.
        clip_model: Model providing ``ln_final``, ``token_embedding`` and a
            ``device`` attribute.
    """

    def __init__(self,
                 prompts,
                 n_ctx,  # number of learnable context tokens
                 CSC,  # True: one context per prompt, False: shared per class
                 class_token_position,  # layouts to build
                 clip_model):
        super().__init__()
        ctx_dim = clip_model.ln_final.weight.shape[0]
        device = clip_model.device

        # One trainable context tensor per (class, position) pair.
        ctx = {}
        for cls in prompts:
            shape = (len(prompts[cls]), n_ctx, ctx_dim) if CSC else (n_ctx, ctx_dim)
            for position in class_token_position:
                ctx_vectors = torch.empty(*shape).to(device)
                nn.init.normal_(ctx_vectors, std=0.02)
                ctx['{}_{}'.format(cls, position)] = nn.Parameter(ctx_vectors, requires_grad=True)
        self.ctx = nn.ParameterDict(ctx)  # to be optimized

        # "X X ... X <prompt>." -- the X tokens only reserve n_ctx slots that
        # forward() overwrites with the learned context.
        prompt_prefix = " ".join(["X"] * n_ctx)
        _tokenizer = SimpleTokenizer()
        prompts_lens = {}
        tokenized_prompts = {}
        for cls in prompts:
            texts = [prompt.replace("_", " ") for prompt in prompts[cls]]
            # Token count of the bare prompt text, used to split it from the
            # trailing tokens in the "middle" and "front" layouts.
            prompts_lens[cls] = [len(_tokenizer.encode(text)) for text in texts]
            tokenized_prompts[cls] = torch.cat(
                [tokenize(prompt_prefix + " " + text + ".") for text in texts]
            ).to(device)

        with torch.no_grad():
            embeddings = {cls: clip_model.token_embedding(tokens)
                          for cls, tokens in tokenized_prompts.items()}

        # NOTE(review): kept in a plain dict (as before), so these tensors are
        # not registered buffers: they are absent from state_dict() and are
        # not moved by Module.to().
        self.register_embeddings = {}
        for cls, embedding in embeddings.items():
            # First token of every prompt.
            self.register_embeddings['{}_token_prefix'.format(cls)] = embedding[:, :1, :]
            # Everything after the n_ctx placeholder slots.
            self.register_embeddings['{}_token_suffix'.format(cls)] = embedding[:, 1 + n_ctx:, :]

        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.prompts_lens = prompts_lens
        self.class_token_position = class_token_position

    def _assemble(self, position, prefix, ctx, suffix, name_lens):
        """Concatenate prefix, context and suffix embeddings for one layout.

        Raises:
            ValueError: If ``position`` is not "end", "middle" or "front".
        """
        if position == "end":
            return torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )
        if position not in ("middle", "front"):
            # Explicit error instead of `assert`, which vanishes under -O and
            # would let an unknown position fall through to the "front" layout.
            raise ValueError(
                "unknown class token position {!r}; expected 'end', 'middle' or 'front'".format(position)
            )

        half_n_ctx = self.n_ctx // 2
        rows = []
        for i, p_len in enumerate(name_lens):
            prefix_i = prefix[i: i + 1, :, :]         # (1, 1, dim)
            class_i = suffix[i: i + 1, :p_len, :]     # (1, name_len, dim)
            suffix_i = suffix[i: i + 1, p_len:, :]    # (1, *, dim)
            ctx_i = ctx[i: i + 1, :, :]               # (1, n_ctx, dim)
            if position == "middle":
                # The prompt text sits between the two halves of the context.
                parts = [
                    prefix_i,
                    ctx_i[:, :half_n_ctx, :],
                    class_i,
                    ctx_i[:, half_n_ctx:, :],
                    suffix_i,
                ]
            else:
                # "front": the prompt text comes before the whole context.
                parts = [prefix_i, class_i, ctx_i, suffix_i]
            rows.append(torch.cat(parts, dim=1))
        return torch.cat(rows, dim=0)

    def forward(self):
        """Build the prompt embeddings for every class.

        Returns:
            Dict mapping each class key to a tensor in which the prompts for
            each configured position are stacked along dim 0, in the order of
            ``class_token_position``.

        Raises:
            ValueError: If a configured position is not "end", "middle" or
                "front".
        """
        cls_prompts = {}
        for cls in self.tokenized_prompts:
            prefix = self.register_embeddings['{}_token_prefix'.format(cls)]
            suffix = self.register_embeddings['{}_token_suffix'.format(cls)]
            name_lens = self.prompts_lens[cls]

            per_position = []
            for position in self.class_token_position:
                ctx = self.ctx['{}_{}'.format(cls, position)]
                if ctx.dim() == 2:
                    # Shared context: broadcast it over all prompts of the class.
                    ctx = ctx.unsqueeze(0).expand(len(name_lens), -1, -1)
                per_position.append(self._assemble(position, prefix, ctx, suffix, name_lens))
            cls_prompts[cls] = torch.cat(per_position, dim=0)
        return cls_prompts
class PromptMaker(nn.Module,
                  PyTorchModelHubMixin,
                  repo_url="https://github.com/Continual-Mega/Continual-Mega",
                  paper_url="https://arxiv.org/abs/2506.00956"):
    """Produce normalized text features for the 'normal' and 'abnormal' prompt sets.

    Combines a ``PromptLearner`` (learned context spliced into the prompt
    embeddings) with a ``TextEncoder`` built from the same ``clip_model``.
    ``repo_url`` and ``paper_url`` are passed as class keyword arguments to
    the base classes.

    Args:
        prompts: Mapping from class key to a list of prompt strings; must
            contain the keys ``'normal'`` and ``'abnormal'``.
        clip_model: CLIP model handed to ``PromptLearner`` and ``TextEncoder``.
        n_ctx: Number of learnable context tokens per prompt.
        CSC: Whether every prompt gets its own context (True) or the context
            is shared per class (False).
        class_token_position: Layouts to build; each entry must be one of
            ``'end'``, ``'middle'`` or ``'front'``.

    Raises:
        ValueError: If a required class key is missing or a position is
            not recognised.
    """

    def __init__(self,
                 prompts,
                 clip_model,
                 n_ctx: int = 8,  # number of learnable context tokens
                 CSC: bool = True,  # True: one context per prompt
                 class_token_position: list = ['end'],  # default kept so the signature is unchanged
                 ):
        super().__init__()

        # Explicit checks instead of `assert`, which is removed under -O.
        missing = [key for key in ('normal', 'abnormal') if key not in prompts]
        if missing:
            raise ValueError("prompts is missing required class key(s): {}".format(missing))

        # Copy so the shared default list (or a caller-owned list) is never
        # aliased into instance state.
        class_token_position = list(class_token_position)
        invalid = [position for position in class_token_position
                   if position not in ('end', 'middle', 'front')]
        if invalid:
            raise ValueError(
                "invalid class_token_position entries {}; expected 'end', 'middle' or 'front'".format(invalid)
            )

        self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.class_token_position = class_token_position
        self.text_encoder = TextEncoder(clip_model)

    def forward(self):
        """Return the stacked, L2-normalized mean text embedding of each class.

        For every class the prompt embeddings of all configured positions are
        encoded, averaged over dim 0 and divided by their norm.

        Returns:
            Tensor with one column per class (stacked along dim 1), in the
            iteration order of the prompt learner's output.
        """
        prompts = self.prompt_learner()
        n_positions = len(self.class_token_position)

        text_features = []
        for cls, cls_prompts in prompts.items():
            # The learner stacks one copy of the prompts per position, so the
            # token ids are repeated to line up with them row by row.
            tokenized = self.tokenized_prompts[cls].repeat(n_positions, 1)
            class_embedding = self.text_encoder(cls_prompts, tokenized).mean(dim=0)
            text_features.append(class_embedding / class_embedding.norm())
        return torch.stack(text_features, dim=1)