Update visual.py
visual.py
CHANGED
@@ -25,13 +25,11 @@ def sliding_window(matrix, window_size, stride):
     window_cols = (width - window_size[1]) // stride + 1
     images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
     windows = []
-    # pdb.set_trace()
     for i in range(window_rows):
         windows_col = []
         for j in range(window_cols):
             window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
             windows.append(window)
-    # windows.append(windows_col)
     windows.append(images_448)
     images = torch.cat(windows,dim=1)
     images = images.reshape(b*5,c,window_size[0], window_size[0])
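For reference, a minimal self-contained sketch of what this helper does after the cleanup: a 2x2 grid of 448x448 tiles is cut from the input and a bicubically resized global view is appended, so each image becomes five crops. The 896x896 input size and the derivation of b, c, height, width from matrix.shape are assumptions inferred from the window size and the b*5 reshape, not taken verbatim from the file.

import torch
import torch.nn.functional as F

def sliding_window_sketch(matrix, window_size=(448, 448), stride=448):
    # matrix: (b, c, H, W); with H = W = 896 this yields a 2x2 grid of tiles
    b, c, height, width = matrix.shape
    window_rows = (height - window_size[0]) // stride + 1
    window_cols = (width - window_size[1]) // stride + 1
    # global view: the whole image resized down to a single window
    images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
    windows = []
    for i in range(window_rows):
        for j in range(window_cols):
            windows.append(matrix[:, :, i*stride:i*stride + window_size[0],
                                  j*stride:j*stride + window_size[1]])
    windows.append(images_448)          # 4 local tiles + 1 global view
    images = torch.cat(windows, dim=1)  # (b, 5*c, 448, 448)
    return images.reshape(b * 5, c, window_size[0], window_size[1])

# sliding_window_sketch(torch.randn(2, 3, 896, 896)).shape  -> torch.Size([10, 3, 448, 448])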
@@ -145,12 +143,9 @@ class Resampler(nn.Module):
         self.ln_kv = norm_layer(embed_dim)
 
         self.apply(self._init_weights)
-
-        #self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
+
 
     def _init_weights(self, m):
-        # self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
-        #pdb.set_trace()
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
@@ -160,7 +155,6 @@ class Resampler(nn.Module):
             nn.init.constant_(m.weight, 1.0)
 
     def forward(self, x, attn_mask=None):
-        #pdb.set_trace()
         pos_embed = get_abs_pos(self.pos_embed, x.size(1))
 
         x = self.kv_proj(x)
@@ -401,7 +395,6 @@ class VisionTransformer(nn.Module):
             act_layer=act_layer,
             norm_layer=norm_layer,
         )
-        # pdb.set_trace()
         self.attn_pool = Resampler(
             grid_size=int(math.sqrt(n_queries)),
             embed_dim=output_dim,
@@ -418,14 +411,10 @@ class VisionTransformer(nn.Module):
         )
         self.ln_post = norm_layer(output_dim)
         self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
-        # self.attn_pool2.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
 
-    # def initialize_vision_modules(self,lpath):
-    #     self.attn_pool2[0].load_state_dict(torch.load(lpath))
 
     def forward(self, x: torch.Tensor):
-
-        #torch.save(self.attn_pool.state_dict(), '/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth')
+
         x = x.to(
             dtype=self.transformer.get_cast_dtype(),
             device=self.transformer.get_cast_device(),
@@ -442,7 +431,6 @@ class VisionTransformer(nn.Module):
         x = x.permute(1, 0, 2)  # NLD -> LND
         x = self.transformer(x)
         x = x.permute(1, 0, 2)  # LND -> NLD
-        # pdb.set_trace()
         src_size = int(math.sqrt(x.shape[1]))
         x = x.reshape(x.shape[0]//5,5,-1, x.shape[-1])
         x1 = x[:,4,:,:]
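For orientation, a hedged sketch of the view bookkeeping around this hunk: the batch enters the transformer as b*5 crops produced by sliding_window, and the reshape above regroups the tokens per original image so the global resized view (index 4) can be handled separately from the four local tiles. The example sizes below are assumptions for illustration, not values from the file.

import torch

b, L, D = 2, 1024, 1664                  # assumed: batch, tokens per crop, hidden dim
x = torch.randn(b * 5, L, D)             # transformer output, 5 views per image

x = x.reshape(x.shape[0] // 5, 5, -1, x.shape[-1])  # (b, 5, L, D)
x1 = x[:, 4, :, :]                                  # global 448-view tokens, (b, L, D)
tiles = x[:, :4, :, :]                              # the four local 448x448 tiles
# in the next hunk both branches pass through post_pro and are concatenated along dim=1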
@@ -454,7 +442,6 @@ class VisionTransformer(nn.Module):
         x1 = self.attn_pool(x1)
         x = self.post_pro(x)
         x1 = self.post_pro(x1)
-        # return x1
         return torch.cat([x,x1],dim=1)
 
     def post_pro(self, x):
@@ -465,7 +452,7 @@ class VisionTransformer(nn.Module):
 
     def encode(self, image_paths: List[str]):
         images = []
-
+
         for image_path in image_paths:
             try:
                 if image_path.startswith("http://") or image_path.startswith("https://"):
@@ -474,7 +461,6 @@ class VisionTransformer(nn.Module):
                 image = self.image_transform(Image.open(image_path).convert("RGB"))
             except:
                 image = torch.zeros((3, 448*2, 448*2))
-            # pdb.set_trace()
             images.append(image)
         images = torch.stack(images, dim=0)
         windows = sliding_window(images,window_size=(448,448),stride=448)
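Taken together, the encode path in these hunks reads images (local paths or http(s) URLs), falls back to a blank 896x896 tensor on failure, stacks the batch, and hands it to sliding_window. A rough usage sketch, assuming a hypothetical torchvision-style stand-in for self.image_transform (the real transform is defined elsewhere in the file and likely also normalizes):

from typing import List
import torch
from PIL import Image
from torchvision import transforms

# hypothetical stand-in for self.image_transform
image_transform = transforms.Compose([
    transforms.Resize((896, 896)),
    transforms.ToTensor(),
])

def encode_sketch(image_paths: List[str]) -> torch.Tensor:
    images = []
    for image_path in image_paths:
        try:
            # the real method also fetches http(s) URLs before opening
            image = image_transform(Image.open(image_path).convert("RGB"))
        except Exception:
            # unreadable inputs fall back to a blank 896x896 image
            image = torch.zeros((3, 448 * 2, 448 * 2))
        images.append(image)
    images = torch.stack(images, dim=0)  # (N, 3, 896, 896)
    # reuse the sliding_window helper from the top of visual.py:
    # four 448x448 tiles plus one resized global view per image
    return sliding_window(images, window_size=(448, 448), stride=448)  # (N*5, 3, 448, 448)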