Airin-chan committed on
Commit
3d2606c
·
verified ·
1 Parent(s): 8df6d84

Upload 2 files

Browse files
Files changed (2) hide show
  1. VLMEncoderLCT.pth +3 -0
  2. lctvlm.py +275 -0
VLMEncoderLCT.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a8391c37c910f14c646fa4de1c567dd1bb111d727fb9636c5ef8f59aa5afe05
3
+ size 9323183
lctvlm.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """LCTVLM.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ekzgx_mlQEjCXDZlg6qdKd951zUJDBcT
8
+ """
9
+
10
+ import torch
11
+ from torch import nn
12
+ from typing import Optional
13
class LCMBlock(nn.Module):
    """
    LCM (Latent Connected Model) block: two parallel linear "perception"
    branches are summed, passed through a bounded "magnitude" projection,
    and added back to the input as a residual.

    Fix: the Dropout module was constructed but never applied in ``forward``,
    so ``drop_rate`` had no effect. It is now applied to the latent update
    before the residual addition (a no-op in eval mode, so inference with
    existing checkpoints is unchanged).
    """

    def __init__(self, d_model: int, drop_rate: float = 0.1):
        """
        Args:
            d_model: feature dimension of the model.
            drop_rate: dropout probability applied to the latent update.
        """
        super().__init__()
        self.step1 = nn.Linear(d_model, d_model)
        self.step2 = nn.Linear(d_model, d_model)
        self.magnitude = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(drop_rate)
        self.gelu1 = nn.GELU(approximate='tanh')
        self.gelu2 = nn.GELU(approximate='tanh')
        self.tanh = nn.Tanh()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm residual: normalize once, run both perception branches,
        # combine, bound the update with tanh, then add to the raw input.
        normx = self.norm(x)
        step1 = self.gelu1(self.step1(normx))
        step2 = self.gelu2(self.step2(normx))
        laten = self.magnitude(step1 + step2)
        laten = self.tanh(laten)
        laten = self.drop(laten)  # fix: dropout was previously unused
        return x + laten
48
+
49
class PathEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    A strided Conv2d performs patchify + projection in one shot; the
    resulting feature map is flattened into a (batch, n_patches,
    embedding_dim) token sequence.
    """

    def __init__(self, image_size, path_size, embedding_dim):
        super().__init__()
        # kernel == stride == patch size -> each patch is projected independently
        self.projection = nn.Conv2d(
            in_channels=3,
            out_channels=embedding_dim,
            kernel_size=path_size,
            stride=path_size,
        )
        # number of patch tokens produced for a square image
        self.n_path = (image_size // path_size) ** 2

    def forward(self, x):
        # (B, 3, H, W) -> (B, C, H/p, W/p) -> (B, C, N) -> (B, N, C)
        feats = self.projection(x)
        return feats.flatten(2).transpose(1, 2)
60
+
61
class PositionalEncoding(nn.Module):
    """Prepend a learnable [CLS] token and add learned position embeddings.

    Both the position table (covering n_path + 1 tokens, CLS included) and
    the CLS token are trainable parameters drawn from N(0, 0.02^2).
    """

    def __init__(self, n_path, embedding_dim):
        super().__init__()
        # NOTE: creation order matters for RNG reproducibility — position first.
        self.position = nn.Parameter(
            torch.normal(mean=0.0, std=0.02, size=(1, n_path + 1, embedding_dim))
        )
        self.cls_token = nn.Parameter(
            torch.normal(mean=0.0, std=0.02, size=(1, 1, embedding_dim))
        )

    def forward(self, x):
        # Replicate the single CLS token across the batch, prepend it,
        # then add the (broadcast) position table.
        cls = self.cls_token.repeat(x.shape[0], 1, 1)
        tokens = torch.cat([cls, x], dim=1)
        return tokens + self.position
72
+
73
class VisionLCTBlock(nn.Module):
    """Vision encoder block: pre-norm self-attention followed by an LCM block.

    Attention is bidirectional (no mask), 4 heads, batch-first.
    """

    def __init__(self, d_model, drop_rate):
        super().__init__()
        self.Attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=4, dropout=drop_rate, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.lcmblock = LCMBlock(d_model, drop_rate)

    def forward(self, x):
        # Pre-norm residual attention, then the feed-forward-style LCM block.
        normed = self.norm(x)
        attn_out, _ = self.Attention(normed, normed, normed)
        return self.lcmblock(x + attn_out)
86
+
87
class LMLCTBlock(nn.Module):
    """Language-model block: causal pre-norm self-attention + LCM block.

    A boolean upper-triangular mask (True above the diagonal) prevents each
    position from attending to later positions.
    """

    def __init__(self, d_model, drop_rate):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=4, dropout=drop_rate, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.lcmblock = LCMBlock(d_model, drop_rate)

    def forward(self, x):
        seq_len = x.shape[1]
        # Masked entries (future positions) are True and thus excluded.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        normed = self.norm(x)
        attn_out, _ = self.attention(normed, normed, normed, attn_mask=causal_mask)
        return self.lcmblock(x + attn_out)
102
+
103
class QFormersBlock(nn.Module):
    """One Q-Former layer: query self-attention, cross-attention into the
    vision features, then a position-wise FFN — each as a pre-norm residual.

    Fixes over the previous version:
    - ``self.norm2`` was registered twice; the duplicate assignment is
      removed (the final parameter set is unchanged).
    - the cross-attention input was normalized from the self-attention
      *delta* (``q2``) instead of the accumulated query state ``q``; it now
      follows the same pre-norm residual pattern as the other sublayers.
    """

    def __init__(self, d_model, drop_rate):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(drop_rate)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(d_model, num_heads=4, dropout=drop_rate, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads=4, dropout=drop_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, Query: torch.Tensor, vision_feats: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None):
        """
        Args:
            Query: (B, n_query, d_model) query tokens.
            vision_feats: (B, n_vision, d_model) features used as keys/values
                in the cross-attention.
            attn_mask: optional mask for the query self-attention.

        Returns:
            Updated queries, same shape as ``Query``.
        """
        q = Query
        # 1) self-attention over the queries (pre-norm residual)
        qnorm = self.norm1(q)
        q2, _ = self.attention(qnorm, qnorm, qnorm, attn_mask=attn_mask)
        q = q + q2

        # 2) cross-attention: queries attend to the vision features
        qnorm2 = self.norm2(q)  # fix: was self.norm2(q2)
        q3, _ = self.cross_attn(qnorm2, vision_feats, vision_feats)
        q = q + q3

        # 3) feed-forward network (pre-norm residual)
        q = q + self.FFN(self.norm3(q))
        return q
136
+
137
+
138
+
139
class QFormer(nn.Module):
    """Stack of Q-Former blocks that distills vision features into a fixed
    number of learnable query tokens, optionally projected into the
    language model's hidden size.
    """

    def __init__(self, dim: int = 768,
                 num_query: int = 32,
                 depth: int = 6,
                 num_head: int = 4,
                 drop_rate: float = 0.1,
                 proj_to_lm_dim: Optional[int] = None):
        super().__init__()
        self.dim = dim
        self.num_query = num_query
        self.depth = depth
        # NOTE(review): num_head is stored but never forwarded —
        # QFormersBlock hardcodes 4 attention heads.
        self.num_head = num_head
        self.drop_rate = drop_rate
        self.proj_to_lm_dim = proj_to_lm_dim

        # Learnable query tokens, shared (expanded) across the batch.
        self.query_embed = nn.Parameter(torch.randn(1, num_query, dim))
        self.layers = nn.ModuleList(
            QFormersBlock(dim, drop_rate) for _ in range(depth)
        )
        self.outnorm = nn.LayerNorm(dim)
        # Optional projection into the LM embedding space.
        self.proj_to_lm: Optional[nn.Linear] = None
        if proj_to_lm_dim is not None:
            self.proj_to_lm = nn.Linear(dim, proj_to_lm_dim)

    def forward(self, vision_feats: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None):
        batch = vision_feats.shape[0]
        out = self.query_embed.expand(batch, -1, -1).contiguous()

        for block in self.layers:
            out = block(out, vision_feats, attn_mask)

        out = self.outnorm(out)
        if self.proj_to_lm is not None:
            out = self.proj_to_lm(out)
        return out
174
+
175
+ def _check_tensor_ok(t: torch.Tensor, name="tensor"):
176
+ if torch.isnan(t).any():
177
+ raise RuntimeError(f"{name} contains NaN values")
178
+ if torch.isinf(t).any():
179
+ raise RuntimeError(f"{name} contains Inf values")
180
+
181
+ # ---------- PretrainedVIT forward (fixed) ----------
182
class PretrainedVIT(nn.Module):
    """ViT-style autoencoder: patch-embed -> positional encoding ->
    VisionLCTBlock stack -> FFN, then a 3-stage ConvTranspose2d decoder that
    upsamples the patch-token grid back to image space.

    NOTE(review): the decoder upsamples the (H/patch, W/patch) grid by a
    fixed factor of 8 (three stride-2 transposed convs), so the output
    matches the input resolution only when patch_size == 8 — with the
    defaults (224, 16) the output is 112x112. Confirm against training code.
    """

    def __init__(self, image_size: int = 224, patch_size: int = 16,
                 embdding_dim: int = 256, n_block: int = 4):
        """
        Args:
            image_size: input image side length (square images assumed).
            patch_size: patch side length; must divide image_size.
            embdding_dim: token embedding dimension (sic — name kept for
                checkpoint/caller compatibility).
            n_block: number of VisionLCTBlock encoder layers.
        """
        super().__init__()
        self.Pathembedding = PathEmbedding(image_size, patch_size, embdding_dim)
        # Patch-grid dimensions stored here so forward() can reshape tokens
        # without relying on sqrt(N) (robust to non-square token counts).
        self.patch_H = image_size // patch_size
        self.patch_W = image_size // patch_size

        self.PositionalEncoding = PositionalEncoding(self.Pathembedding.n_path, embdding_dim)
        self.VisionACT = nn.ModuleList([VisionLCTBlock(embdding_dim, 0.15) for _ in range(n_block)])
        self.ffn = nn.Sequential(
            nn.Linear(embdding_dim, embdding_dim * 4),
            nn.GELU(approximate='tanh'),
            nn.Linear(embdding_dim * 4, embdding_dim)
        )

        # Decoder: each ConvTranspose2d doubles spatial resolution (x8 total)
        # while reducing channels embdding_dim -> dim/2 -> dim/4 -> 3 (RGB).
        self.upsampling = nn.Sequential(
            nn.ConvTranspose2d(embdding_dim, embdding_dim // 2, kernel_size=2, stride=2),
            nn.GELU(approximate='tanh'),
            nn.ConvTranspose2d(embdding_dim // 2, embdding_dim // 4, kernel_size=2, stride=2),
            nn.GELU(approximate='tanh'),
            nn.ConvTranspose2d(embdding_dim //4, 3, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Encode *x* to patch tokens and decode back to an image tensor.

        Args:
            x: [B, 3, H_image, W_image] input batch.

        Returns:
            [B, 3, 8*patch_H, 8*patch_W] decoded image tensor.

        Raises:
            RuntimeError: if the token count is neither n_patches nor
                n_patches + 1, if the grid reshape fails, or if any
                intermediate tensor contains NaN/Inf.
        """
        # 1) patch embedding -> token sequence
        tokens = self.Pathembedding(x)  # expected shape (B, N or N+1, C)
        # 2) positional encoding (adds a CLS token), encoder blocks, FFN
        tokens = self.PositionalEncoding(tokens)
        for blk in self.VisionACT:
            tokens = blk(tokens)
        tokens = self.ffn(tokens)

        # 3) strip the CLS token robustly: accept either N or N+1 tokens
        expected_n = self.patch_H * self.patch_W
        B, N_all, C = tokens.shape

        if N_all == expected_n + 1:
            # CLS token at index 0 (consistent with PositionalEncoding)
            token_patches = tokens[:, 1:, :]  # (B, expected_n, C)
        elif N_all == expected_n:
            token_patches = tokens  # already just patch tokens
        else:
            # informative error instead of a silent downstream crash
            raise RuntimeError(
                f"Unexpected token length: got N_all={N_all}, expected {expected_n} or {expected_n+1}. "
                "Check PathEmbedding / PositionalEncoding outputs."
            )

        # 4) reshape (B, N, C) -> (B, C, patch_H, patch_W) using stored dims
        token_patches = token_patches.contiguous()
        try:
            token_map = token_patches.view(B, self.patch_H, self.patch_W, C).permute(0, 3, 1, 2)
        except Exception as e:
            # re-raise with full shape context to ease debugging
            raise RuntimeError(f"Reshape to grid failed: B={B}, N={token_patches.shape[1]}, C={C}, "
                               f"patch_H={self.patch_H}, patch_W={self.patch_W}. Error: {e}")

        # 5) fail fast on NaN/Inf to avoid opaque CUDA-side asserts later
        _check_tensor_ok(token_map, name="token_map before upsampling")

        # 6) decode the token grid back to image space
        out = self.upsampling(token_map)

        # optional clamp / tanh mapping depending on your training scale (0..1 or -1..1)
        # out = out.clamp(0., 1.)  # only if your training uses 0..1 images

        _check_tensor_ok(out, name="output image (after upsampling)")

        return out
265
+
266
def noicing_image(image: torch.Tensor, time_steps, b_start=1e-3, b_end=0.07, T=50):
    """Apply forward-diffusion noise to *image* at the given timestep(s).

    Builds a linear beta schedule of length T, takes the cumulative product
    of the alphas, and returns
    ``sqrt(a_bar_t) * image + sqrt(1 - a_bar_t) * eps`` with eps ~ N(0, I).

    Args:
        image: (B, C, H, W) input batch.
        time_steps: timestep index (or tensor of per-sample indices),
            each in [0, T).
        b_start: first beta of the linear schedule.
        b_end: last beta of the linear schedule.
        T: number of diffusion steps in the schedule.

    Returns:
        Noised image tensor, same shape as ``image``.
    """
    betas = torch.linspace(b_start, b_end, T, device=image.device)
    alpha_bar = torch.cumprod(1 - betas, dim=0)

    # Per-sample scales, broadcast over (C, H, W).
    signal_scale = torch.sqrt(alpha_bar[time_steps]).view(-1, 1, 1, 1)
    noise_scale = torch.sqrt(1 - alpha_bar[time_steps]).view(-1, 1, 1, 1)
    return signal_scale * image + noise_scale * torch.randn_like(image)
275
+