infinity1096
initial commit
c8b42eb
"""
Encoder Class for RADIO (Nvidia)
"""
from typing import List, Optional, Tuple, Union
import torch
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
class RADIOEncoder(UniCeptionViTEncoderBase):
    """
    UniCeption RADIO Encoder.

    Wraps a RADIO model loaded through torch hub ("NVlabs/RADIO") and returns
    its spatial features as a (B, C, H / patch_size, W / patch_size) tensor.
    """

    def __init__(
        self,
        name: str,
        data_norm_type: str = "radio",
        patch_size: int = 16,
        model_version: str = "radio_v2.5-l",
        pretrained_checkpoint_path: Optional[str] = None,
        eradio_input_shape: Optional[tuple] = None,
        torch_hub_force_reload: bool = False,
        keep_first_n_layers: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        RADIO Encoder for extracting spatial features from images.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "radio"
            patch_size (int): Patch size for the encoder. Must be 14 for "radio_v2.5-g" and 16 otherwise. Default: 16
            model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
            pretrained_checkpoint_path (Optional[str]): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None
            eradio_input_shape (Optional[tuple]): Input shape (height, width) for E-RADIO models. Required when the model version contains "e-radio". Default: None
            torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
            keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Must be strictly less than the number of blocks. Default: None

        Raises:
            KeyError: If `model_version` is not one of the supported versions.
            AssertionError: If the patch size does not match the model version, if
                `keep_first_n_layers` is not smaller than the number of blocks, or if
                `eradio_input_shape` is missing for an E-RADIO model.
        """
        # Init the base class
        super().__init__(
            name=name,
            data_norm_type=data_norm_type,
            patch_size=patch_size,
            *args,
            **kwargs,
        )

        # Init the RADIO Encoder specific attributes
        self.model_version = model_version
        embed_dim_per_version = {
            "radio_v2.5-b": 768,
            "radio_v2.5-l": 1024,
            "radio_v2.5-h": 1280,
            "radio_v2.5-g": 1536,
            "e-radio_v2": 1536,
        }
        if self.model_version not in embed_dim_per_version:
            # Still a KeyError (as a plain dict lookup would raise), but with an actionable message
            raise KeyError(
                f"Unsupported RADIO model version: {self.model_version!r}. "
                f"Supported versions: {sorted(embed_dim_per_version)}"
            )
        self.enc_embed_dim = embed_dim_per_version[self.model_version]

        if self.model_version == "radio_v2.5-g":
            assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g"
        else:
            assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO"

        # Load the pretrained RADIO model from torch hub
        print(f"Loading pretrained {self.model_version} from torch hub")
        try:  # Requires internet access
            self.model = torch.hub.load(
                "NVlabs/RADIO",
                "radio_model",
                version=self.model_version,
                progress=True,
                skip_validation=True,
                force_reload=torch_hub_force_reload,
            )
        except Exception as hub_error:
            # Catch Exception rather than using a bare except so that KeyboardInterrupt / SystemExit
            # still propagate. The retry deliberately omits force_reload so the cached copy can be used.
            print(f"Torch hub load failed ({hub_error!r}); retrying without force_reload to use the local cache")
            self.model = torch.hub.load(
                "NVlabs/RADIO",
                "radio_model",
                version=self.model_version,
                progress=True,
                skip_validation=True,
            )

        # Delete the excess blocks if keep_first_n_layers is specified
        if keep_first_n_layers is not None:
            assert keep_first_n_layers < len(
                self.model.model.blocks
            ), "keep_first_n_layers must be less than the number of blocks"
            print(f"Keeping only the first {keep_first_n_layers} layers of the model")
            self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers])

        # Set the optimal window size for E-RADIO models
        if "e-radio" in self.model_version:
            assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models"
            self.model.model.set_optimal_window_size(eradio_input_shape)

        # Load the custom pretrained checkpoint if provided
        if pretrained_checkpoint_path is not None:
            print(f"Loading custom pretrained RADIO checkpoint from {pretrained_checkpoint_path}")
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects; only load
            # checkpoints from trusted sources.
            # map_location="cpu" lets checkpoints saved on a GPU load on hosts without one;
            # load_state_dict copies the values into the module's existing parameters.
            ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu", weights_only=False)
            print(self.load_state_dict(ckpt["model"]))

    def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
        """
        RADIO Encoder Forward Pass

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.

        Returns:
            ViTEncoderOutput: Output data from the encoder. `features` has shape
                (B, enc_embed_dim, H / patch_size, W / patch_size).
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
        assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
        _batch_size, channels, height, width = encoder_input.image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Forward pass through the RADIO encoder; the summary output is not used here
        _summary, features = self.model(encoder_input.image)

        # Resize the features to the expected shape
        # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
        features = features.permute(0, 2, 1)
        features = features.reshape(
            -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
        ).contiguous()

        return ViTEncoderOutput(features=features)
class RADIOIntermediateFeatureReturner(RADIOEncoder, IntermediateFeatureReturner):
    """
    Intermediate Feature Returner for UniCeption RADIO Encoder.

    Combines the RADIO encoder with the intermediate-feature settings
    (indices, normalization, early stopping) and returns the features of the
    selected blocks as (B, C, H / patch_size, W / patch_size) tensors.
    """

    def __init__(
        self,
        name: str,
        data_norm_type: str = "radio",
        patch_size: int = 16,
        model_version: str = "radio_v2.5-l",
        pretrained_checkpoint_path: Optional[str] = None,
        eradio_input_shape: Optional[tuple] = None,
        indices: Union[int, List[int]] = [-1],
        norm_intermediate: bool = True,
        stop_early: bool = False,
        intermediates_only: bool = True,
        feature_adaptor: Optional[str] = None,
        keep_first_n_layers: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        Intermediate Feature Returner for the RADIO Encoder.

        Args:
            name (str): Name of the encoder.
            data_norm_type (str): Image normalization type. Default: "radio"
            patch_size (int): Patch size for the encoder. Default: 16
            model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
            pretrained_checkpoint_path (Optional[str]): Path to the pretrained checkpoint if using custom trained version of RADIO.
            eradio_input_shape (Optional[tuple]): Input shape (height, width) for E-RADIO models. Default: None
            indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options:
                - int: Return the last n layers.
                - List[int]: Return the intermediate layers at the specified indices.
                - None: Return the features of every block of the model.
            norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
            stop_early (bool, optional): Whether to stop early. Defaults to False.
            intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
            feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2".
            keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None.

        Raises:
            ValueError: If `feature_adaptor` is neither None nor "dino_v2".
        """
        # Init the base classes
        RADIOEncoder.__init__(
            self,
            name=name,
            data_norm_type=data_norm_type,
            patch_size=patch_size,
            model_version=model_version,
            pretrained_checkpoint_path=pretrained_checkpoint_path,
            eradio_input_shape=eradio_input_shape,
            keep_first_n_layers=keep_first_n_layers,
            *args,
            **kwargs,
        )

        # The default for `indices` is a mutable list shared by every call. It cannot be replaced
        # by None in the signature because None already means "all blocks" (see below), so hand a
        # private copy to the mixin instead of the shared (or caller-owned) list object.
        indices_for_instance = list(indices) if isinstance(indices, list) else indices
        IntermediateFeatureReturner.__init__(
            self,
            indices=indices_for_instance,
            norm_intermediate=norm_intermediate,
            stop_early=stop_early,
            intermediates_only=intermediates_only,
        )

        # Convert indices to absolute indices if indices is None
        if self.indices is None:
            self.indices = list(range(len(self.model.model.blocks)))

        self.feature_adaptor = feature_adaptor
        if self.feature_adaptor == "dino_v2":
            # Initialize a dummy radio encoder with the adaptor setting
            dummy_model = torch.hub.load(
                "NVlabs/RADIO",
                "radio_model",
                version=self.model_version,
                progress=True,
                skip_validation=True,
                adaptor_names="dino_v2",
            )

            # Extract its feature converter weights
            self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp

            # Update the embedding dimension because the features have been projected
            self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features

            # Only the converter is kept; drop the reference to the rest of the dummy model
            del dummy_model
        elif self.feature_adaptor is not None:
            raise ValueError("Unsupported feature adaptor. Supported: dino_v2")

    def forward(
        self, encoder_input: ViTEncoderInput
    ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
        """
        RADIO Encoder Forward Pass with Intermediate Feature Return

        Args:
            encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.

        Returns:
            Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
                If `intermediates_only` is True, returns a list of intermediate features.
                Otherwise, returns a tuple with the final features and a list of intermediate features.
        """
        # Check image normalization type
        self._check_data_normalization_type(encoder_input.data_norm_type)

        # Check the dtype and shape of the input image
        assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
        assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
        batch_size, channels, height, width = encoder_input.image.shape
        assert channels == 3, "Input must have 3 channels"
        assert (
            height % self.patch_size == 0 and width % self.patch_size == 0
        ), f"Input shape must be divisible by patch size: {self.patch_size}"

        # Run the model, requesting token-major ("NLC") outputs without prefix tokens
        model_outputs = self.model.forward_intermediates(
            encoder_input.image,
            indices=self.indices,
            return_prefix_tokens=False,
            norm=self.norm_intermediate,
            stop_early=self.stop_early,
            output_fmt="NLC",
            intermediates_only=self.intermediates_only,
        )

        # Extract the final features and intermediate features accordingly
        final_features, intermediate_features = None, None
        if self.intermediates_only:
            intermediate_features = model_outputs
        else:
            final_features = model_outputs[0].features.contiguous()
            intermediate_features = model_outputs[1]

        # Patch-grid size used to fold the token axis back into (H, W)
        Hp, Wp = height // self.patch_size, width // self.patch_size

        # Convert final features
        if final_features is not None:
            # Optionally convert the features using the feature adaptor
            if self.feature_adaptor is not None:
                final_features = self.spatial_feature_converter(final_features)

            # Convert to BCHW and package
            final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2)
            final_features = ViTEncoderOutput(features=final_features)

        # Convert intermediate features
        if intermediate_features is not None:
            num_intermediate = len(intermediate_features)

            # Stack all layers along the batch axis so the adaptor and reshape run once
            all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0)
            if self.feature_adaptor is not None:
                all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor)

            # Convert to BCHW
            all_intermediate_feats_tensor = all_intermediate_feats_tensor.view(
                num_intermediate * batch_size, Hp, Wp, -1
            ).permute(0, 3, 1, 2)

            # Split the stacked tensor back into one entry per requested layer
            all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0)
            intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats]

        # Return the final features and intermediate features accordingly
        if self.intermediates_only:
            return intermediate_features
        else:
            return final_features, intermediate_features
if __name__ == "__main__":
    # Smoke test: each listed RADIO variant must construct without raising
    for version_name in ("radio_v2.5-b", "radio_v2.5-l"):
        standard_encoder = RADIOEncoder(name="RADIOv2.5", model_version=version_name)

    # E-RADIO is additionally given the (height, width) it will be run at
    eradio_hw = (512, 512)
    efficient_encoder = RADIOEncoder(name="E-RADIO", model_version="e-radio_v2", eradio_input_shape=eradio_hw)

    print("All RADIO Encoders have been initialized successfully!")

    # Intermediate Feature Returner Tests
    print("Running Intermediate Feature Returner Tests...")

    # Arguments shared by every returner built below
    returner_kwargs = {"name": "RADIOv2.5", "model_version": "radio_v2.5-b"}

    def _fresh_dummy_input():
        # A new random 1 x 3 x 224 x 224 image is drawn for every test case
        return ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")

    # Case 1: an integer index requests the last n layers (here n = 6)
    returner = RADIOIntermediateFeatureReturner(indices=6, **returner_kwargs)
    sample = _fresh_dummy_input()
    result = returner(sample)
    assert isinstance(result, list), "Output must be a list of intermediate features"
    assert isinstance(result[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(result) == 6, "Output must have length of intermediate features equal to the number of indices"

    # Case 2: an explicit list of layer indices
    returner = RADIOIntermediateFeatureReturner(indices=[0, 2, 4, 6], **returner_kwargs)
    sample = _fresh_dummy_input()
    result = returner(sample)
    assert isinstance(result, list), "Output must be a list of intermediate features"
    assert isinstance(result[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert len(result) == 4, "Output must have length of intermediate features equal to the number of indices"

    # Case 3: final + intermediate features with intermediate normalization turned off
    returner = RADIOIntermediateFeatureReturner(norm_intermediate=False, intermediates_only=False, **returner_kwargs)
    sample = _fresh_dummy_input()
    result = returner(sample)
    assert isinstance(result, tuple), "Output must be a tuple with final features and intermediate features"
    final_out = result[0]
    assert isinstance(final_out, ViTEncoderOutput), "First element of output must be the final features"
    intermediate_outs = result[1]
    assert isinstance(intermediate_outs, list), "Second element of output must be a list of intermediate features"
    assert isinstance(intermediate_outs[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    # The comparison is skipped when the model's final norm is an Identity module
    norm_is_identity = isinstance(returner.model.model.norm, torch.nn.Identity)
    if not norm_is_identity:
        assert not torch.equal(
            final_out.features, intermediate_outs[0].features
        ), "Final features and intermediate features must be different"

    # Case 4: with normalization on, the last intermediate output must equal the final output
    returner = RADIOIntermediateFeatureReturner(norm_intermediate=True, intermediates_only=False, **returner_kwargs)
    sample = _fresh_dummy_input()
    result = returner(sample)
    assert isinstance(result, tuple), "Output must be a tuple with final features and intermediate features"
    final_out = result[0]
    assert isinstance(final_out, ViTEncoderOutput), "First element of output must be the final features"
    intermediate_outs = result[1]
    assert isinstance(intermediate_outs, list), "Second element of output must be a list of intermediate features"
    assert isinstance(intermediate_outs[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
    assert torch.equal(
        final_out.features, intermediate_outs[0].features
    ), "Final features and intermediate features must be same"

    print("All Intermediate Feature Returner Tests have passed successfully!")