| | """ |
| | VIT的transformer结构没有因果掩码,因为任意一个位置都能访问其它位置,它们之间没有因果关系,或者说关系很弱 |
| | |
| | 文本生成仍然考虑因果掩码。 |
| | """ |
| | import torch.nn.functional as F |
| | from VIT import model as VIT |
| | from Text_Encoder import text_encoder as transformer |
| | import torch.nn as nn |
| | import torch |
| | from Text_Encoder import MLP |
| |
|
class Prompt_block(nn.Module):
    """Learnable prompt tokens spliced into a text-embedding sequence.

    Inserts ``config.prompt_num`` trainable embedding vectors immediately
    after the first token (the BOS position) of every sequence in the batch.
    The same prompt vectors are shared (broadcast) across the batch.
    """

    def __init__(self, config):
        super().__init__()
        # One learnable vector per prompt token.
        self.prompt_embedding = nn.Embedding(
            config.prompt_num,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )

    def forward(self, text_embeddings):
        """Insert the prompt embeddings after position 0.

        Args:
            text_embeddings: float tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len + prompt_num, hidden_size):
            [first token, prompt tokens, remaining tokens].
        """
        b = text_embeddings.size(0)
        n, dim = self.prompt_embedding.weight.size()
        # expand() broadcasts the shared prompt weights over the batch
        # without copying. (A dead per-sample loop variant was removed here;
        # it referenced an undefined `index` and was never executed.)
        return torch.cat(
            (
                text_embeddings[:, :1, :],
                self.prompt_embedding.weight.expand(b, n, dim),
                text_embeddings[:, 1:, :],
            ),
            dim=1,
        )
| | |
| | |
| | |
| | |
| |
|
class CLIP(nn.Module):
    """CLIP-style dual encoder: a ViT image tower and a transformer text tower
    projected into a shared embedding space, with prompt-tuning support.

    NOTE(review): the ``torch.empty`` parameters below are left uninitialized
    on purpose — this model is only usable after ``load_state_dict`` restores
    a trained checkpoint (as the script at the bottom of this file does).
    """

    def __init__(self, config):
        super().__init__()
        self.visual = VIT  # vision tower (module-level instance from VIT.py)
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )
        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num
        self.transformer = transformer  # text tower (from Text_Encoder.py)

        self.prompt_block = Prompt_block(config)
        # Uninitialized; overwritten by the checkpoint load.
        self.positional_embedding = nn.Parameter(
            torch.empty(config.max_position_embeddings, config.hidden_size,
                        device=config.device))
        self.ln_final = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps,
            dtype=config.dtype, device=config.device)
        self.text_projection = nn.Parameter(
            torch.empty(config.hidden_size, config.hidden_size,
                        device=config.device))
        # FIX: the original computed torch.empty([]) * logit_scale_init, i.e.
        # multiplied *uninitialized memory* by the constant — a garbage value.
        # (Masked in practice only because load_state_dict overwrites it.)
        # Initialize the log-temperature directly instead.
        self.logit_scale = nn.Parameter(
            torch.tensor(config.logit_scale_init,
                         dtype=config.dtype, device=config.device),
            requires_grad=False,
        )

    def encode_image(self, img, use_emotion=True):
        """Return the image [CLS] embedding from the vision tower."""
        cls_embedding = self.visual(img, use_emotion)
        return cls_embedding

    def encode_text(self, text, use_emotion=True):
        """Encode token ids of shape (batch, seq) into pooled text features.

        Pools at the EOT position (argmax of the token ids — the EOT id is
        the largest id in the CLIP vocabulary), then applies the projection.
        """
        b, n = text.size()
        index = text.argmax(dim=-1)  # EOT position per sequence
        text_embedding = self.token_embedding(text)

        # If the caller left exactly prompt_num positions free, splice in the
        # prompt tokens and shift the EOT index by the same amount.
        if n == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            # FIX: shift by the configured prompt count rather than the
            # hard-coded constant 20 (correct only while prompt_num == 20).
            index = index + self.prompt_num
        position_embedding = self.positional_embedding[None, :text_embedding.shape[1], :].to(self.dtype)
        text_embedding = position_embedding + text_embedding
        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)

        # Gather the (possibly shifted) EOT embedding per sequence, project.
        text_embedding = text_embedding[torch.arange(text.shape[0]), index]
        text_embedding = text_embedding @ self.text_projection.to(self.dtype)

        return text_embedding

    def forward(self, image, text, use_emotion=True):
        """Return (logits_per_image, logits_per_text) similarity matrices."""
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)

        # L2-normalize so the scaled dot product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
| |
|
class Config:
    """Hyper-parameter container for the EmotionCLIP model."""

    def __init__(self):
        # Text encoder
        self.vocab_size = 49408
        self.max_position_embeddings = 77
        self.prompt_num = 20

        # Vision encoder
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32

        # Shared transformer dimensions
        self.hidden_size = 512
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"

        # Runtime placement
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")

        # CLIP temperature (stored in log space; exp() applied at use site)
        self.logit_scale_init = 4.6052

        # Prompt-tuning aliases
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size
| | |
# Build the model and restore trained weights from the local checkpoint.
config=Config()
model=CLIP(config)

# weights_only=True restricts torch.load to tensor data (safe deserialization);
# strict=True requires the checkpoint keys to match the model exactly.
model.load_state_dict(torch.load(r'./EmotionCLIP-V2.pth',weights_only=True,map_location='cpu'),strict=True)
# NOTE(review): the triple-quoted string below is dead code kept for reference.
# When enabled, it freezes every parameter except prompt/prefix/LayerNorm ones,
# i.e. a prompt-tuning fine-tune setup.
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not in name and 'ln' not in name: # 如果参数名中不包含'prefix'
        print(name,"'s requires_grad turn off.")
        param.requires_grad = False # 冻结该参数
    else:
        print(name,"'s requires_grad turn on.")
        param.requires_grad = True # 允许该参数进行训练
"""
| |
|
| | |
| | |
import pickle
from PIL import Image
import numpy as np
import clip

# SECURITY NOTE(review): pickle.load executes arbitrary code embedded in the
# file — only load preprocess.pkl / tokenize.pkl from a trusted source.
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
# Preprocess the test image and add a batch dimension: (1, C, H, W).
image = preprocess(Image.open("Dog sad.jpg")).unsqueeze(0).to(device)

# Candidate emotion labels for zero-shot classification.
labels=[
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral'
]
text_list=[ f"This picture conveys a sense of {label}" for label in labels]
# context_length=57 = 77 (max positions) - 20 (prompt tokens), which triggers
# the prompt-insertion branch inside CLIP.encode_text.
tokens= tokenizer(text_list,
                  context_length=57).to(device)
| |
|
# Zero-shot emotion recognition (emotion pathway enabled by default).
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Pick the label with the highest image-to-text probability.
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("情感识别:", probs)
print("预测的情感标签:", predicted_label)
| |
|
| | |
# Generic (non-emotion) zero-shot classification: same pipeline, but with
# use_emotion=False to bypass the emotion-specific path in both towers.
labels=[
    'spider',
    'dog',
    'cat',
    'fish'
]
text_list=[ f"This is a {label}" for label in labels]
tokens= tokenizer(text_list,context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Pick the label with the highest image-to-text probability.
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("泛化识别:", probs)
print("预测的泛化标签:", predicted_label)
| |
|
| |
|