.ipynb_checkpoints/EmotionCLIP-checkpoint.py DELETED
@@ -1,151 +0,0 @@
- """
- The ViT transformer has no causal mask: every position can attend to every other
- position, so there is no (or only a very weak) causal relationship between them.
-
- Text generation, by contrast, still needs a causal mask.
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from VIT import model as VIT
- from Text_Encoder import text_encoder as transformer
- from Text_Encoder import MLP
-
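- # For reference: a minimal sketch (illustrative only, not used by this model) of the
- # causal attention a generative text transformer would apply. The attention blocks in
- # this repo call F.scaled_dot_product_attention without any mask, i.e. they are fully
- # bidirectional, which is what the docstring above refers to.
- def _causal_attention_example(q, k, v):
-     # q, k, v: [batch, heads, seq_len, head_size]; with is_causal=True each position
-     # may only attend to itself and earlier positions.
-     return F.scaled_dot_product_attention(q, k, v, is_causal=True)
-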
- class Prompt_block(nn.Module):
-     def __init__(self,config):
-         super(Prompt_block,self).__init__()
-         self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
-     def forward(self,text_embeddings):
-         b,_,_=text_embeddings.size()
-         n,dim=self.prompt_embedding.weight.size()
-         """
-         # Alternative: insert the prompt embeddings at each sample's EOT index instead of after the BOS token.
-         new_embeddings=[]
-         for batch,index_ in enumerate(index):
-             text_embedding=text_embeddings[batch]
-             text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
-             new_embeddings.append(text_embedding)
-         stacked_embedding=torch.stack(new_embeddings,dim=0)
-         return stacked_embedding
-         """
-         # Insert the learnable prompt embeddings right after the BOS token.
-         text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
-         return text_embeddings
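- # Shape sketch (based on the Config below: prompt_num=20, hidden_size=512, and the
- # context_length of 57 used by the tokenizer calls at the bottom of this file):
- # prompt_block=Prompt_block(config)
- # out=prompt_block(torch.zeros(2,57,512,dtype=config.dtype,device=config.device))
- # out.shape == torch.Size([2,77,512])   # 20 prompt vectors inserted behind the BOS position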
-
-
- class CLIP(nn.Module):
-     def __init__(self,config):
-         super().__init__()
-         self.visual=VIT
-         self.device=config.device
-         self.dtype=config.dtype
-         self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
-         self.max_position_embeddings=config.max_position_embeddings
-         self.prompt_num=config.prompt_num
-         self.transformer=transformer
-         # add a prompt block
-         self.prompt_block=Prompt_block(config)
-         self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
-         self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
-         self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
-         # logit_scale is a fixed scalar; the pretrained checkpoint overwrites it on load.
-         self.logit_scale=nn.Parameter(torch.full([],config.logit_scale_init,dtype=config.dtype,device=config.device),requires_grad=False)
-     def encode_image(self,img,use_emotion=True):
-         cls_embedding=self.visual(img,use_emotion)
-         # cls_embedding: [batch_size,512] projected CLS feature
-         return cls_embedding
-     def encode_text(self,text,use_emotion=True):
-         # 20 token positions are reserved for the prompt embeddings
-         b,n=text.size()
-         # argmax over token ids locates the EOT token (highest id in the CLIP vocabulary)
-         index=text.argmax(dim=-1)
-         text_embedding=self.token_embedding(text)
-         #text_embedding=self.prompt_block(index,text_embedding)
-         if n==self.max_position_embeddings-self.prompt_num:
-             text_embedding=self.prompt_block(text_embedding)
-             # the inserted prompt vectors shift every EOT position right by prompt_num
-             index=index+torch.tensor(self.prompt_num,device=index.device,dtype=index.dtype)
-         position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
-         text_embedding=position_embedding+text_embedding
-         text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
-         text_embedding=self.ln_final(text_embedding)
-         # pool the hidden state at the EOT position and project it
-         #print(index[0],text_embedding.shape)
-         text_embedding=text_embedding[torch.arange(text.shape[0]),index]
-         text_embedding=text_embedding@self.text_projection.to(self.dtype)
-         return text_embedding
-
-     def forward(self,image,text,use_emotion=True):
-         image_features=self.encode_image(image,use_emotion)
-         text_features=self.encode_text(text,use_emotion)
-         # normalized features
-         image_features=image_features/image_features.norm(dim=-1,keepdim=True)
-         text_features=text_features/text_features.norm(dim=-1,keepdim=True)
-         # cosine similarity as logits
-         logit_scale=self.logit_scale.exp()
-         logits_per_image=logit_scale*image_features@text_features.t()
-         logits_per_text=logits_per_image.t()
-         # shape = [global_batch_size, global_batch_size]
-         return logits_per_image,logits_per_text
-
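- # Training sketch (assumed; no training code is part of this file): the two logit
- # matrices returned above are what CLIP's symmetric InfoNCE objective consumes.
- def clip_contrastive_loss(logits_per_image,logits_per_text):
-     # the ground-truth pairing is the diagonal of the similarity matrix
-     labels=torch.arange(logits_per_image.shape[0],device=logits_per_image.device)
-     loss_i=F.cross_entropy(logits_per_image,labels)
-     loss_t=F.cross_entropy(logits_per_text,labels)
-     return (loss_i+loss_t)/2
-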
- class Config:
-     def __init__(self):
-         self.vocab_size=49408
-         self.image_dim=768
-         self.num_patches=49
-         self.patch_size=32
-         self.hidden_size=512
-         self.prompt_num=20
-         self.max_position_embeddings=77
-         self.num_hidden_layers=12
-         self.num_attention_heads=8
-         self.head_size=64
-         self.layer_norm_eps=1e-5
-         self.activation_function="Quickgelu"
-         self.dtype=torch.float16
-         self.device=torch.device("cuda:0")
-         self.logit_scale_init=4.6052
-         self.num_virtual_tokens=20
-         self.token_dim=self.hidden_size
-         self.encoder_hidden_size=self.hidden_size
-
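- # Note: max_position_embeddings (77) = tokenizer context_length (57, used below)
- # + prompt_num (20), so encode_text only inserts the prompt block when the incoming
- # sequence leaves exactly 20 positions free.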
- config=Config()
- model=CLIP(config)
- # load the pretrained weights
- model.load_state_dict(torch.load(r'/root/autodl-tmp/true_Emoset/EmotionCLIP_v2.bin',weights_only=True,map_location='cpu'),strict=True)
- """
- # Optional: freeze everything except the prefix, prompt and LayerNorm parameters.
- for name, param in model.named_parameters():
-     if 'prefix' not in name and 'prompt' not in name and 'ln' not in name:
-         print(name,"'s requires_grad turned off.")
-         param.requires_grad = False   # freeze this parameter
-     else:
-         print(name,"'s requires_grad turned on.")
-         param.requires_grad = True    # keep this parameter trainable
- """
-
- # compile the model
- #model=torch.compile(model)
- import pickle
- from PIL import Image
- import clip
- with open('./preprocess.pkl','rb') as f:
-     preprocess = pickle.load(f)
- with open('./tokenize.pkl','rb') as f:
-     tokenizer=pickle.load(f)
- device=config.device
- image = preprocess(Image.open("spider.jpg")).unsqueeze(0).to(device)
- text = tokenizer(["This picture conveys a sense of fear", "This picture conveys a sense of contentment", "This picture conveys a sense of anger","This picture conveys a sense of sadness","This picture conveys a sense of neutral","This picture conveys a sense of disgust","This picture conveys a sense of excitement","This picture conveys a sense of awe","This picture conveys a sense of amusement"],context_length=57).to(device)
- #context_length=57 leaves room for the 20 prompt tokens (57+20=77)
- with torch.no_grad():
-     logits_per_image, logits_per_text = model(image.to(config.dtype), text)
-     probs = logits_per_image.softmax(dim=-1).cpu().numpy()
- print("Emotion recognition:",probs)
- # save the weights with the prefix merged in
- torch.save(model.state_dict(),'./upload/EmotionCLIP-V2.pth')
- # generalization check
- """
- text=tokenizer(['This is a spider.','This is a dog','This is a cat'],context_length=57).to(device)
- with torch.no_grad():
-     logits_per_image, logits_per_text = model(image.to(config.dtype), text,use_emotion=False)
-     probs = logits_per_image.softmax(dim=-1).cpu().numpy()
- print("Generalization check:",probs)
- """
 
.ipynb_checkpoints/Text_Encoder-checkpoint.py DELETED
@@ -1,192 +0,0 @@
- import torch
- import torch.nn as nn
- import math
- from torch.nn.attention import SDPBackend, sdpa_kernel
- from torch.nn import functional as F
-
-
- # Prefix-tuning encoder in the style of the Hugging Face implementation.
- class PrefixEncoder(torch.nn.Module):
-     def __init__(self,config):
-         super(PrefixEncoder,self).__init__()
-         self.config=config
-         self.device=config.device
-         self.dtype=config.dtype
-         self.num_virtual_tokens=config.num_virtual_tokens
-         self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
-         self.token_dim=config.token_dim
-         self.encoder_hidden_size=config.encoder_hidden_size
-         self.num_layers=config.num_layers
-         self.transformer=torch.nn.Sequential(
-             torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
-             torch.nn.Tanh(),
-             torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
-         )
-     def forward(self,input_ids,batch_size):
-         input_ids=input_ids.unsqueeze(0)
-         prefix_embedding=self.embedding(input_ids)
-         prefix_embedding=self.transformer(prefix_embedding)
-         # Bake the computed prefix into the state dict and drop the generator modules,
-         # so the saved checkpoint contains the merged prefix only. Note that this makes
-         # the module single-use: a second forward pass would fail because
-         # self.embedding/self.transformer no longer exist.
-         self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
-         prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
-         prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
-         prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
-         del self.embedding
-         del self.transformer
-         k,v=prefix_embedding.chunk(2,dim=0)
-         return (k.squeeze(0),v.squeeze(0))
-
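- # Shape sketch (illustrative, using the Config defined below: num_layers=12,
- # num_virtual_tokens=20, token_dim=512): the encoder returns per-layer prefix keys and
- # values of shape [num_layers, batch_size, num_virtual_tokens, token_dim], so
- # prefix_k[i] and prefix_v[i] are the extra key/value tokens prepended inside layer i.
- # prefix_k,prefix_v=PrefixEncoder(config)(torch.arange(20,device=config.device),batch_size=2)
- # prefix_k.shape == torch.Size([12,2,20,512])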
- class Transformer(nn.Module):
-     def __init__(self,config):
-         super(Transformer,self).__init__()
-         self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
-         self.prefix=PrefixEncoder(config)
-         prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
-         self.register_buffer("prefix_tokens",prefix_tokens)
-     def forward(self,hidden_state,use_emotion):
-         if use_emotion:
-             #print("text transformer prefix enabled.")
-             b,n,h=hidden_state.shape
-             prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
-             for index,resblock in enumerate(self.resblocks):
-                 hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
-             return hidden_state
-         else:
-             for index,resblock in enumerate(self.resblocks):
-                 hidden_state=resblock(hidden_state)
-             return hidden_state
-
-
- class ResidualAttentionBlock(nn.Module):
-     def __init__(self,config):
-         super(ResidualAttentionBlock,self).__init__()
-         self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-         self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-         #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
-         self.attn=MultiHeadAttention(config)
-         self.mlp=MLP(config)
-     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
-         residual=hidden_state
-         hidden_state=self.ln_1(hidden_state)
-         hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
-         hidden_state=residual+hidden_state
-         residual=hidden_state
-         hidden_state=self.ln_2(hidden_state)
-         hidden_state=self.mlp(hidden_state)
-         hidden_state=residual+hidden_state
-         return hidden_state
-
- class MultiHeadAttention(nn.Module):
-     def __init__(self,config):
-         super(MultiHeadAttention,self).__init__()
-         self.hidden_size=config.hidden_size
-         self.num_heads=config.num_heads
-         self.head_size=self.hidden_size//self.num_heads
-         # packed QKV projection weight and bias (frozen; loaded from the pretrained checkpoint)
-         self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
-         self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
-         #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
-     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
-         b,n,c=hidden_state.shape
-         #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
-         #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
-         #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
-         q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
-         if prefix_k is not None and prefix_v is not None:
-             # prepend the prefix key/value tokens to the sequence
-             k=torch.cat((prefix_k,k),dim=1)
-             v=torch.cat((prefix_v,v),dim=1)
-             #print("model origin k :",k[:,0,0])
-         bk,nk,hk=k.shape
-         bq,nq,hq=q.shape
-         q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
-         k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
-         v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
-         attention_logits=F.scaled_dot_product_attention(q, k, v)
-         attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
-         attention_output=self.out_proj(attention_logits)
-         return attention_output
-
-
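- # Attention shapes with the prefix (sketch, based on the Config below): queries keep
- # the text length n (77 positions), while keys/values grow to num_virtual_tokens+n =
- # 20+77 = 97, so every text token can additionally attend to the 20 learned prefix tokens.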
- class GELU(nn.Module):
-     """
-     The error function erf:
-     erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
-     where x is the upper limit of integration and t is the integration variable,
-     running from 0 to x; exp(-t^2) is the Gaussian-shaped integrand evaluated at each
-     point t. The integral therefore accumulates Gaussian probability mass over the
-     interval [0, x], which is what GELU uses to gate its input.
-     """
-     def forward(self,x):
-         # math.sqrt is used because torch.sqrt expects a tensor argument
-         return 0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))
-
- class QuickGELU(nn.Module):
-     def __init__(self):
-         super(QuickGELU,self).__init__()
-     def forward(self,x):
-         old_dtype=x.dtype
-         x=x.to(torch.float32)
-         return (x*torch.sigmoid(1.702*x)).to(old_dtype)
-
-
- class MLP(nn.Module):
-     def __init__(self,config):
-         super(MLP,self).__init__()
-         self.hidden_size=config.hidden_size
-         self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
-         self.gelu=QuickGELU()
-         self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
-     def forward(self,hidden_state):
-         hidden_state=self.c_fc(hidden_state)
-         hidden_state=self.gelu(hidden_state)
-         hidden_state=self.c_proj(hidden_state)
-         return hidden_state
-
- class Config:
-     def __init__(self,vocab_size,max_position_embeddings,hidden_size,num_layers,num_heads,device,dtype):
-         self.vocab_size=vocab_size
-         self.max_position_embeddings=max_position_embeddings
-         self.hidden_size=hidden_size
-         self.num_layers=num_layers
-         self.num_heads=num_heads
-         self.device=device
-         self.dtype=dtype
-         self.norm_eps=1e-5
-         self.num_virtual_tokens=20
-         self.token_dim=hidden_size
-         self.encoder_hidden_size=hidden_size
- config=Config(
-     vocab_size=49408,
-     max_position_embeddings=77,
-     hidden_size=512,
-     num_layers=12,
-     num_heads=8,
-     device=torch.device('cuda:0'),
-     dtype=torch.float16
- )
- # Stand-alone text encoder (unused by EmotionCLIP, which drives the Transformer directly); kept for reference.
- class TextEncoder(nn.Module):
-     def __init__(self,config):
-         super(TextEncoder,self).__init__()
-         self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,device=config.device,dtype=config.dtype)
-         self.positional_embedding=nn.Parameter(torch.zeros(config.max_position_embeddings,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
-         self.transformer=Transformer(config)
-         self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-     def forward(self,input_ids,use_emotion=True):
-         b,n=input_ids.shape
-         token_embeddings=self.token_embedding(input_ids)
-         position_ids=torch.arange(n,device=config.device,dtype=torch.long).unsqueeze(0).expand(b,n)
-         position_embeddings=self.positional_embedding[position_ids]
-         embeddings=token_embeddings+position_embeddings
-         embeddings=self.transformer(embeddings,use_emotion)
-         embeddings=self.ln_final(embeddings)
-         return embeddings
-
- text_encoder=Transformer(config)
 
.ipynb_checkpoints/VIT-checkpoint.py DELETED
@@ -1,243 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- import math
- import os
- import sys
- from torch.nn.attention import SDPBackend, sdpa_kernel
-
- # prefix tuning, following the Hugging Face implementation
- class PrefixEncoder(torch.nn.Module):
-     def __init__(self,config):
-         super(PrefixEncoder,self).__init__()
-         self.config=config
-         self.device=config.device
-         self.dtype=config.dtype
-         self.num_virtual_tokens=config.num_virtual_tokens
-         self.embedding=torch.nn.Embedding(config.num_virtual_tokens,config.token_dim,device=config.device,dtype=config.dtype)
-         self.token_dim=config.token_dim
-         self.encoder_hidden_size=config.encoder_hidden_size
-         self.num_layers=config.num_layers
-         self.transformer=torch.nn.Sequential(
-             torch.nn.Linear(self.token_dim,self.encoder_hidden_size,device=self.device,dtype=self.dtype),
-             torch.nn.Tanh(),
-             torch.nn.Linear(self.encoder_hidden_size,self.num_layers*2*self.token_dim,device=self.device,dtype=self.dtype),
-         )
-     def forward(self,input_ids,batch_size):
-         input_ids=input_ids.unsqueeze(0)
-         prefix_embedding=self.embedding(input_ids)
-         prefix_embedding=self.transformer(prefix_embedding)
-         # Bake the computed prefix into the state dict and drop the generator modules;
-         # as in Text_Encoder, this makes the module single-use after the merge.
-         self.register_parameter("prefix_embedding",nn.Parameter(prefix_embedding,requires_grad=False))
-         prefix_embedding=prefix_embedding.expand(batch_size,self.num_virtual_tokens,self.num_layers*2*self.token_dim)
-         prefix_embedding=prefix_embedding.reshape(batch_size,self.num_virtual_tokens,self.num_layers,2,self.token_dim)
-         prefix_embedding=prefix_embedding.permute(3,2,0,1,4)
-         del self.embedding
-         del self.transformer
-         k,v=prefix_embedding.chunk(2,dim=0)
-         return (k.squeeze(0),v.squeeze(0))
-
- # Fixed sinusoidal positional encoding from the original Transformer paper
- # (unused here; the ViT below adds its pretrained positional_embedding parameter instead).
- def position_embedding(x,position_ids):
-     hidden_size=x.size(2)
-     seq_len=x.size(1)
-     div_term=torch.exp(torch.arange(0,hidden_size,2,device=x.device).float()*(-math.log(10000.0)/hidden_size))
-     positional_encoding=torch.zeros(seq_len,hidden_size,device=x.device)
-     positional_encoding[:,0::2]=torch.sin(position_ids.float()[:,None]*div_term)
-     positional_encoding[:,1::2]=torch.cos(position_ids.float()[:,None]*div_term)
-     positional_encoding=positional_encoding.unsqueeze(0)
-     return positional_encoding
-
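- # For reference, the formula implemented above (d = hidden_size, i indexing channel pairs):
- #   PE(pos, 2i)   = sin(pos / 10000^(2i/d))
- #   PE(pos, 2i+1) = cos(pos / 10000^(2i/d))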
-
-
- class VisionTransformer(nn.Module):
-     def __init__(self,config):
-         super(VisionTransformer,self).__init__()
-         self.image_channel=config.image_channel
-         self.hidden_size=config.hidden_size
-         self.norm_eps=config.norm_eps
-         self.patch_size=config.patch_size
-         self.output_dim=config.output_dim
-         self.dtype=config.dtype
-         self.num_patches=config.num_patches
-         self.num_virtual_tokens=config.num_virtual_tokens if hasattr(config,"num_virtual_tokens") else None
-         self.conv1=nn.Conv2d(self.image_channel,self.hidden_size,self.patch_size,stride=self.patch_size,bias=False,device=config.device,dtype=config.dtype)
-         self.ln_pre=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-         self.transformer=Transformer(config)
-         #self.position_ids=torch.arange(config.num_patches+1,dtype=torch.long,device=config.device)
-         #self.position_embeddings=nn.Parameter(torch.zeros(1,config.num_patches+1,config.hidden_size))
-         #nn.init.normal_(self.position_embeddings)
-         # CLS token, used for the image-level representation; frozen here and loaded from the pretrained checkpoint
-         #self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size,device=config.device))
-         self.class_embedding=nn.Parameter(torch.empty(config.hidden_size,device=config.device),requires_grad=False)
-         # the positional embedding is likewise a pretrained parameter (kept frozen here)
-         self.positional_embedding=nn.Parameter(torch.empty(config.num_patches+1,config.hidden_size,device=config.device),requires_grad=False)
-         # projection to the shared embedding space (pretrained, frozen)
-         self.proj=nn.Parameter(torch.empty(config.hidden_size,config.output_dim,device=config.device,dtype=config.dtype),requires_grad=False)
-         self.ln_post=nn.LayerNorm(self.hidden_size,eps=self.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-     def forward(self,hidden_state,use_emotion):
-         b,c,h,w=hidden_state.shape
-         # patch embeddings from the convolution
-         hidden_state=self.conv1(hidden_state)
-         hidden_state=hidden_state.reshape(b,self.hidden_size,-1).transpose(1,2)
-         # prepend the CLS token embedding
-         hidden_state=torch.cat((self.class_embedding.expand(b,1,-1).to(hidden_state.dtype),hidden_state),dim=1)
-         # alternative: the fixed sinusoidal embedding from the original Transformer paper
-         #hidden_state=hidden_state+position_embedding(hidden_state,self.position_ids)
-         hidden_state=hidden_state+self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
-         hidden_state=self.ln_pre(hidden_state)
-         hidden_state=self.transformer(hidden_state,use_emotion)
-         # take the CLS token output; the patch outputs are not used
-         cls_state=hidden_state[:,0,:]
-         cls_state=self.ln_post(cls_state)
-         cls_state=torch.matmul(cls_state,self.proj)
-         #image_state=hidden_state[:,1:,:]
-         #image_state size (batch_size,49,768)
-         return cls_state
-
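- # Patch arithmetic (assuming the standard 224x224 CLIP preprocessing, with the
- # ViTConfig at the bottom of this file): a 224x224 input with patch_size=32 yields
- # (224/32)^2 = 49 patches, so the transformer sees 49 patch tokens + 1 CLS token = 50
- # positions, matching positional_embedding of shape [num_patches+1, hidden_size] = [50, 768].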
- class Transformer(nn.Module):
-     def __init__(self,config):
-         super(Transformer,self).__init__()
-         self.resblocks=nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
-         self.prefix=PrefixEncoder(config)
-         prefix_tokens=torch.arange(0,config.num_virtual_tokens,device=config.device,dtype=torch.long)
-         self.register_buffer("prefix_tokens",prefix_tokens)
-     def forward(self,hidden_state,use_emotion):
-         if use_emotion:
-             b,n,h=hidden_state.shape
-             prefix_k,prefix_v=self.prefix(self.prefix_tokens,b)
-             for index,resblock in enumerate(self.resblocks):
-                 # feed this layer's prefix key/value tensors into the resblock, where they are concatenated before attention
-                 hidden_state=resblock(hidden_state,prefix_k[index],prefix_v[index])
-             return hidden_state
-         else:
-             for index,resblock in enumerate(self.resblocks):
-                 # no prefix: plain residual attention blocks
-                 hidden_state=resblock(hidden_state)
-             return hidden_state
-
-
- class ResidualAttentionBlock(nn.Module):
-     def __init__(self,config):
-         super(ResidualAttentionBlock,self).__init__()
-         self.ln_1=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-         self.ln_2=nn.LayerNorm(config.hidden_size,eps=config.norm_eps,elementwise_affine=True,device=config.device,dtype=config.dtype)
-         #self.attn=nn.MultiheadAttention(config.hidden_size,config.num_heads,device=config.device,dtype=config.dtype)
-         self.attn=MultiHeadAttention(config)
-         self.mlp=MLP(config)
-     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
-         residual=hidden_state
-         hidden_state=self.ln_1(hidden_state)
-         hidden_state=self.attn(hidden_state,prefix_k,prefix_v)
-         hidden_state=residual+hidden_state
-         residual=hidden_state
-         hidden_state=self.ln_2(hidden_state)
-         hidden_state=self.mlp(hidden_state)
-         hidden_state=residual+hidden_state
-         return hidden_state
-
- class MultiHeadAttention(nn.Module):
-     def __init__(self,config):
-         super(MultiHeadAttention,self).__init__()
-         self.hidden_size=config.hidden_size
-         self.num_heads=config.num_heads
-         self.head_size=self.hidden_size//self.num_heads
-         # packed QKV projection weight and bias (frozen; loaded from the pretrained checkpoint)
-         self.in_proj_weight=nn.Parameter(torch.empty(3*config.hidden_size,config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
-         self.in_proj_bias=nn.Parameter(torch.empty(3*config.hidden_size,device=config.device,dtype=config.dtype),requires_grad=False)
-         #self.q_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         #self.k_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         #self.v_linear=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device)
-         self.out_proj=nn.Linear(self.hidden_size,self.hidden_size,bias=True,device=config.device,dtype=config.dtype)
-     def forward(self,hidden_state,prefix_k=None,prefix_v=None):
-         b,n,h=hidden_state.shape
-         #q=self.q_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
-         #k=self.k_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,3,1)
-         #v=self.v_linear(hidden_state).view(b,n,self.num_heads,self.head_size).permute(0,2,1,3)
-         q,k,v=(torch.matmul(hidden_state,self.in_proj_weight.T)+self.in_proj_bias.expand(b,n,-1)).chunk(3,dim=-1)
-         if prefix_k is not None and prefix_v is not None:
-             # prepend the prefix key/value tokens to the sequence
-             #print("original prefix_k.shape",prefix_k.shape)
-             k=torch.cat((prefix_k,k),dim=1)
-             v=torch.cat((prefix_v,v),dim=1)
-             #print("model original k :",k[:,0,0])
-         bk,nk,hk=k.shape
-         bq,nq,hq=q.shape
-         q=q.view(bq,nq,self.num_heads,self.head_size).permute(0,2,1,3)
-         k=k.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
-         v=v.view(bk,nk,self.num_heads,self.head_size).permute(0,2,1,3)
-         attention_logits=F.scaled_dot_product_attention(q, k, v)
-         attention_logits=attention_logits.permute(0,2,1,3).contiguous().view(bk,nq,self.hidden_size)
-         attention_output=self.out_proj(attention_logits)
-         return attention_output
-
-
- class GELU(nn.Module):
-     """
-     The error function erf:
-     erf(x) = 2/sqrt(pi) * integral(exp(-t^2), t=0..x)
-     where x is the upper limit of integration and t is the integration variable,
-     running from 0 to x; exp(-t^2) is the Gaussian-shaped integrand evaluated at each
-     point t. The integral therefore accumulates Gaussian probability mass over the
-     interval [0, x], which is what GELU uses to gate its input.
-     """
-     def forward(self,x):
-         old_dtype=x.dtype
-         x=x.to(torch.float32)
-         # math.sqrt is used because torch.sqrt expects a tensor argument
-         return (0.5*x*(1.0+torch.erf(x/math.sqrt(2.0)))).to(old_dtype)
-
- class QuickGELU(nn.Module):
-     def __init__(self):
-         super(QuickGELU,self).__init__()
-     def forward(self,x):
-         old_dtype=x.dtype
-         x=x.to(torch.float32)
-         return (x*torch.sigmoid(1.702*x)).to(old_dtype)
-
-
- class MLP(nn.Module):
-     def __init__(self,config):
-         super(MLP,self).__init__()
-         self.hidden_size=config.hidden_size
-         self.c_fc=nn.Linear(self.hidden_size,4*self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
-         self.gelu=QuickGELU()
-         self.c_proj=nn.Linear(self.hidden_size*4,self.hidden_size,device=config.device,bias=True,dtype=config.dtype)
-     def forward(self,hidden_state):
-         hidden_state=self.c_fc(hidden_state)
-         hidden_state=self.gelu(hidden_state)
-         hidden_state=self.c_proj(hidden_state)
-         return hidden_state
-
-
- class ViTConfig:
-     def __init__(self,image_channel,hidden_size,num_heads,num_layers,patch_size,num_patches,output_dim,norm_eps,device):
-         self.image_channel=image_channel
-         self.hidden_size=hidden_size
-         self.num_heads=num_heads
-         self.num_layers=num_layers
-         self.patch_size=patch_size
-         self.num_patches=num_patches
-         self.norm_eps=norm_eps
-         self.device=device
-         self.dtype=torch.float16
-         # unused; the patch-token count is already given by num_patches (+1 for the CLS token)
-         self.patch_token_num=self.hidden_size//self.patch_size**2+1
-         self.output_dim=output_dim
-         self.num_virtual_tokens=20
-         self.token_dim=self.hidden_size
-         self.encoder_hidden_size=self.hidden_size
-
- config=ViTConfig(3,768,12,12,32,49,512,1e-5,torch.device("cuda"))
- model=VisionTransformer(config)