"""
Encoder Class for CroCo & DUSt3R
"""
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
from uniception.models.libs.croco.blocks import Block
from uniception.models.libs.croco.patch_embed import get_patch_embed
from uniception.models.libs.croco.pos_embed import RoPE2D
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
class CroCoEncoder(UniCeptionViTEncoderBase):
    "UniCeption CroCov2 Encoder"

    def __init__(
        self,
        name: str,
        data_norm_type: str,
        patch_embed_cls: str = "PatchEmbedDust3R",
        img_size: Union[int, Tuple[int, int]] = (224, 224),
        patch_size: int = 16,
        enc_embed_dim: int = 1024,
        enc_depth: int = 24,
        enc_num_heads: int = 16,
        mlp_ratio: int = 4,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
        pos_embed: str = "RoPE100",
        pretrained_checkpoint_path: Optional[str] = None,
        override_checkpoint_attributes: bool = False,
        *args,
        **kwargs,
    ):
        """
        References: https://github.com/naver/dust3r, https://github.com/naver/croco

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Input data normalization type.
            patch_embed_cls (str, optional): The class to use for patch embedding.
                Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
            img_size (Union[int, Tuple[int, int]], optional): The size of the input image. Defaults to (224, 224).
            patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
            enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
            enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
            enc_num_heads (int, optional): The number of encoder heads. Defaults to 16.
            mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
            norm_layer (Callable, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
            pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'.
                Only 'RoPE<freq>' variants are supported (e.g. 'RoPE100').
            pretrained_checkpoint_path (Optional[str], optional): Path to the pretrained checkpoint. Defaults to None.
            override_checkpoint_attributes (bool, optional): Skip the consistency checks between the
                constructor arguments and the metadata stored in the checkpoint. Defaults to False.

        Raises:
            ImportError: If RoPE2D (cuRoPE2D) is not available.
            NotImplementedError: If `pos_embed` is not a RoPE variant.
        """
        # Init the base class
        super().__init__(
            name=name,
            data_norm_type=data_norm_type,
            patch_size=patch_size,
            *args,
            **kwargs,
        )

        # Init the CroCo Encoder specific attributes
        self.patch_embed_cls = patch_embed_cls
        self.img_size = img_size
        self.enc_embed_dim = enc_embed_dim
        self.enc_depth = enc_depth
        self.enc_num_heads = enc_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.pretrained_checkpoint_path = pretrained_checkpoint_path
        self.override_checkpoint_attributes = override_checkpoint_attributes

        # Init the positional embedding; only RoPE variants (e.g. "RoPE100") are supported
        self.pos_embed = pos_embed
        if pos_embed.startswith("RoPE"):  # eg RoPE100
            self.enc_pos_embed = None  # nothing to add in the encoder with RoPE
            self.dec_pos_embed = None  # nothing to add in the decoder with RoPE
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            # The trailing digits encode the RoPE base frequency, e.g. "RoPE100" -> 100.0
            freq = float(pos_embed[len("RoPE") :])
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError("Unknown pos_embed " + pos_embed)

        # Init the patch embedding
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # Init the encoder
        self._set_encoder(enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, self.rope)

        # Initialize random weights (overwritten below if a checkpoint is provided)
        self.initialize_weights()

        # Load the pretrained CroCo checkpoint if provided
        if pretrained_checkpoint_path:
            print(f"Loading pretrained CroCo checkpoint from {pretrained_checkpoint_path}")
            # map_location="cpu" so checkpoints saved from GPU runs also load on CPU-only hosts;
            # weights_only=False because the checkpoint carries non-tensor metadata
            # ("data_norm_type", "patch_embed_cls") — only load checkpoints from trusted sources.
            ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))
            if not override_checkpoint_attributes:
                # Verify the constructor arguments agree with the checkpoint metadata
                ckpt_data_norm_type = ckpt["data_norm_type"]
                ckpt_patch_embed_cls = ckpt["patch_embed_cls"]
                assert (
                    data_norm_type == ckpt_data_norm_type
                ), f"Data normalization type {data_norm_type} does not match the checkpoint {ckpt_data_norm_type}."
                assert (
                    patch_embed_cls == ckpt_patch_embed_cls
                ), f"Patch embedding class {patch_embed_cls} does not match the checkpoint {ckpt_patch_embed_cls}."
        else:
            print("No pretrained checkpoint provided. Randomly initializing the CroCo encoder.")

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        "Set the patch embedding scheme"
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def _set_encoder(self, enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, rope):
        "Set the encoder (stack of transformer blocks followed by a final norm)"
        self.enc_blocks = nn.ModuleList(
            [
                Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=rope)
                for _ in range(enc_depth)
            ]
        )
        self.enc_norm = norm_layer(enc_embed_dim)

    def initialize_weights(self):
        "Initialize the weights of the patch embedding and the transformer encoder"
        # Patch embedding
        self.patch_embed._init_weights()
        # Linears and layer norms
        self.apply(self._init_weights)

    def _init_weights(self, m):
        "Initialize the transformer encoder weights"
        if isinstance(m, nn.Linear):
            # We use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        CroCov2 Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Get the true shape of the image for landscape/portrait mode check in patch embedding.
        # getattr(...) also covers inputs that define the attribute but leave it as None,
        # which the previous hasattr check would have forwarded as-is.
        batch_size, _, height, width = encoder_input.image.shape
        true_shape = getattr(encoder_input, "true_shape", None)
        if true_shape is None:
            true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)

        # Embed the image into patches
        features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)

        # Now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            features = blk(features, pos)
        features = self.enc_norm(features)

        # Resize the features to the expected shape
        # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
        features = features.permute(0, 2, 1)
        features = features.reshape(
            -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
        ).contiguous()

        return ViTEncoderOutput(features=features)
class CroCoIntermediateFeatureReturner(CroCoEncoder, IntermediateFeatureReturner):
    "Intermediate Feature Returner for UniCeption CroCo Encoder"
    def __init__(
        self,
        name: str,
        data_norm_type: str,
        patch_embed_cls: str = "PatchEmbedDust3R",
        img_size: Union[int, Tuple[int, int]] = (224, 224),
        patch_size: int = 16,
        enc_embed_dim: int = 1024,
        enc_depth: int = 24,
        enc_num_heads: int = 16,
        mlp_ratio: int = 4,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
        pos_embed: str = "RoPE100",
        pretrained_checkpoint_path: Optional[str] = None,
        indices: Optional[Union[int, List[int]]] = None,
        norm_intermediate: bool = True,
        stop_early: bool = False,
        intermediates_only: bool = True,
        *args,
        **kwargs,
    ):
        """
        Intermediate Feature Returner for the CroCo Encoder.
        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Input data normalization type.
            patch_embed_cls (str, optional): The class to use for patch embedding.
                Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
            img_size (Union[int, Tuple[int, int]], optional): The size of the input image. Defaults to (224, 224).
            patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
            enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
            enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
            enc_num_heads (int, optional): The number of encoder heads. Defaults to 16.
            mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
            norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
            pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'.
                Only 'RoPE<freq>' variants are supported by CroCoEncoder.
            pretrained_checkpoint_path (Optional[str], optional): Path to the pretrained checkpoint. Defaults to None.
            indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options:
            - None: Return all intermediate layers.
            - int: Return the last n layers.
            - List[int]: Return the intermediate layers at the specified indices.
            norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
            stop_early (bool, optional): Whether to stop early. Defaults to False.
            intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
        """
        # Init the base classes.
        # CroCoEncoder.__init__ builds the patch embedding and the transformer blocks
        # and (optionally) loads the pretrained checkpoint.
        CroCoEncoder.__init__(
            self,
            name=name,
            data_norm_type=data_norm_type,
            patch_embed_cls=patch_embed_cls,
            img_size=img_size,
            patch_size=patch_size,
            enc_embed_dim=enc_embed_dim,
            enc_depth=enc_depth,
            enc_num_heads=enc_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            pos_embed=pos_embed,
            pretrained_checkpoint_path=pretrained_checkpoint_path,
            *args,
            **kwargs,
        )
        # IntermediateFeatureReturner.__init__ only stores the feature-return configuration
        # (indices, norm_intermediate, stop_early, intermediates_only) used by forward().
        IntermediateFeatureReturner.__init__(
            self,
            indices=indices,
            norm_intermediate=norm_intermediate,
            stop_early=stop_early,
            intermediates_only=intermediates_only,
        )
    def forward(
        self, encoder_input: ViTEncoderInput
    ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
        """
        CroCov2 Encoder Forward Pass with Intermediate Feature Return
        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
        Returns:
            Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
            If `intermediates_only` is True, returns a list of intermediate features.
            Otherwise, returns a tuple with the final features and a list of intermediate features.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)
        # Get the true shape of the image for landscape/portrait mode check in patch embedding
        batch_size, _, height, width = encoder_input.image.shape
        # NOTE(review): if ViTEncoderInput always defines `true_shape` (even as None), this
        # hasattr check never takes the fallback branch — verify against the dataclass definition.
        if hasattr(encoder_input, "true_shape"):
            true_shape = encoder_input.true_shape
        else:
            true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)
        # Embed the image into patches
        features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)
        # Get indices of the intermediate features to return
        # (take_indices are block indices to capture; max_index is the deepest one needed)
        intermediate_features = []
        take_indices, max_index = feature_take_indices(len(self.enc_blocks), self.indices)
        # Get the blocks based on early stopping
        if torch.jit.is_scripting() or not self.stop_early:  # can't slice blocks in torchscript
            blocks = self.enc_blocks
        else:
            blocks = self.enc_blocks[: max_index + 1]
        # Now apply the transformer encoder and normalization
        for blk_idx, blk in enumerate(blocks):
            features = blk(features, pos)
            if blk_idx in take_indices:
                # Normalize intermediates with final norm layer if enabled
                intermediate_features.append(self.enc_norm(features) if self.norm_intermediate else features)
        # Reshape the intermediate features and convert to ViTEncoderOutput class
        # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
        intermediate_features = [
            intermediate.permute(0, 2, 1)
            .reshape(-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size)
            .contiguous()
            for intermediate in intermediate_features
        ]
        intermediate_features = [ViTEncoderOutput(features=intermediate) for intermediate in intermediate_features]
        # Return only the intermediate features if enabled
        if self.intermediates_only:
            return intermediate_features
        # Normalize and reshape the final features
        # (note: when norm_intermediate is True, the last intermediate equals these final features)
        features = self.enc_norm(features)
        # Resize the features to the expected shape
        # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
        features = features.permute(0, 2, 1)
        features = features.reshape(
            -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
        ).contiguous()
        final_features = ViTEncoderOutput(features=features)
        return final_features, intermediate_features
if __name__ == "__main__":
    # Smoke tests: initialize every supported pre-trained encoder variant, then
    # exercise the intermediate feature returner in its different configurations.
    checkpoint_dir = "../../../checkpoints/encoders"

    # (encoder name, data normalization, checkpoint file, patch embed class, optional img_size)
    encoder_variants = [
        ("croco", "croco", "CroCo_Encoder_224.pth", "PatchEmbedCroCo", None),
        ("dust3r_224", "dust3r", "CroCo_Encoder_224_DUSt3R_linear.pth", "PatchEmbedDust3R", None),
        ("dust3r_512", "dust3r", "CroCo_Encoder_512_DUSt3R_linear.pth", "ManyAR_PatchEmbed", (512, 512)),
        ("dust3r_512_dpt", "dust3r", "CroCo_Encoder_512_DUSt3R_dpt.pth", "ManyAR_PatchEmbed", (512, 512)),
        ("mast3r_512", "dust3r", "CroCo_Encoder_512_MASt3R.pth", "ManyAR_PatchEmbed", (512, 512)),
    ]
    for variant_name, variant_norm, ckpt_file, embed_cls, variant_img_size in encoder_variants:
        extra_kwargs = {} if variant_img_size is None else {"img_size": variant_img_size}
        CroCoEncoder(
            name=variant_name,
            data_norm_type=variant_norm,
            pretrained_checkpoint_path=f"{checkpoint_dir}/{ckpt_file}",
            patch_embed_cls=embed_cls,
            **extra_kwargs,
        )
    print("All CroCo & DUSt3R Encoders have been initialized successfully!")

    # Intermediate Feature Returner Tests
    print("Running Intermediate Feature Returner Tests...")
    dpt_checkpoint = f"{checkpoint_dir}/CroCo_Encoder_512_DUSt3R_dpt.pth"

    def make_returner(**overrides):
        "Build the DUSt3R 512 DPT intermediate feature returner with the given overrides."
        return CroCoIntermediateFeatureReturner(
            name="dust3r_512_dpt",
            data_norm_type="dust3r",
            pretrained_checkpoint_path=dpt_checkpoint,
            patch_embed_cls="ManyAR_PatchEmbed",
            img_size=(512, 512),
            **overrides,
        )

    def make_dummy_input():
        "Build a random 224x224 input batch for the returner tests."
        return ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")

    # Last-n indexing: the final 6 layers should be returned.
    returner = make_returner(indices=6)  # Last 6 layers
    output = returner(make_dummy_input())
    assert isinstance(output, list), "Output must be a list of intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"

    # Explicit indexing: exactly the requested layers should be returned.
    returner = make_returner(indices=[0, 2, 4, 6])  # Specific layers
    output = returner(make_dummy_input())
    assert isinstance(output, list), "Output must be a list of intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"

    # Without intermediate normalization, the last intermediate must differ from the
    # final (normalized) features — unless the final norm is an identity.
    returner = make_returner(indices=[-1], norm_intermediate=False, intermediates_only=False)
    output = returner(make_dummy_input())
    assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
    assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
    assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    if not isinstance(returner.enc_norm, torch.nn.Identity):
        assert not torch.equal(
            output[0].features, output[1][0].features
        ), "Final features and intermediate features must be different"

    # With intermediate normalization, the last intermediate matches the final features exactly.
    returner = make_returner(indices=[-1], norm_intermediate=True, intermediates_only=False)
    output = returner(make_dummy_input())
    assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
    assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
    assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
    assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert torch.equal(
        output[0].features, output[1][0].features
    ), "Final features and intermediate features must be same"
    print("All Intermediate Feature Returner Tests have passed successfully!")