# Source: FlexiCT-2D / models.py — Hugging Face upload by ricklisz123
# ("Upload folder using huggingface_hub", commit fcefac1, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
import torch.nn.functional as F
import torch
import torch.nn.init
from torch import Tensor, nn
from torch.nn.init import trunc_normal_
from .layers import LayerScale, Mlp, PatchEmbed, RMSNorm, RopePositionEmbedding, RopePositionEmbedding3D, SelfAttentionBlock, SwiGLUFFN
from .utils import named_apply
import math
import numpy as np
logger = logging.getLogger("dinov3")

# Name -> feed-forward layer factory, looked up with the ``ffn_layer`` argument.
# The "swigluNN" variants forward ``align_to=NN`` to SwiGLUFFN
# (presumably rounding the hidden width to a multiple of NN — confirm in .layers).
ffn_layer_dict = dict(
    mlp=Mlp,
    swiglu=SwiGLUFFN,
    swiglu32=partial(SwiGLUFFN, align_to=32),
    swiglu64=partial(SwiGLUFFN, align_to=64),
    swiglu128=partial(SwiGLUFFN, align_to=128),
)

# Name -> normalization layer factory, looked up with the ``norm_layer`` argument.
norm_layer_dict = dict(
    layernorm=partial(nn.LayerNorm, eps=1e-6),
    layernormbf16=partial(nn.LayerNorm, eps=1e-5),
    rmsnorm=RMSNorm,
)

# Short dtype names accepted by ``pos_embed_rope_dtype``.
dtype_dict = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
def init_weights_vit(module: nn.Module, name: str = ""):
    """Re-initialize a single module in place (used as a ``named_apply`` callback).

    * ``nn.Linear``: weight ~ truncated normal (std 0.02), bias zeroed. If the
      layer has a non-None ``bias_mask``, the mask is set to 1 everywhere except
      rows ``[out_features // 3, 2 * out_features // 3)``, which are set to 0
      (presumably the key slice of a fused qkv projection — confirm in .layers).
    * LayerNorm / LayerScale / PatchEmbed / PatchEmbedND / PatchEmbed3D / RMSNorm:
      delegated to the module's own ``reset_parameters``; a module matching
      several of these types is reset once per match.

    ``name`` is only accepted to fit the callback signature and is unused.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        mask = getattr(module, "bias_mask", None)
        if mask is not None:
            out_features = module.out_features
            mask.fill_(1)
            mask[out_features // 3 : 2 * out_features // 3].fill_(0)

    # Order matters only for RNG consumption; it mirrors the original check order.
    resettable = (nn.LayerNorm, LayerScale, PatchEmbed, PatchEmbedND, PatchEmbed3D, RMSNorm)
    for module_type in resettable:
        if isinstance(module, module_type):
            module.reset_parameters()
from torch.distributed._tensor import DTensor, distribute_tensor, Replicate
def _replicate_like(x: torch.Tensor, like: DTensor) -> DTensor:
# Make 'x' a replicated DTensor on the same mesh as 'like'
if isinstance(x, DTensor):
return x
return distribute_tensor(x, device_mesh=like.device_mesh, placements=[Replicate()])
class PatchEmbed3D(nn.Module):
    """
    3D volume to patch embedding: (B, C, D, H, W) -> (B, N, E), N = D' * H' * W'.

    A single Conv3d whose kernel and stride both equal ``patch_size`` cuts the
    volume into non-overlapping patches and projects each one to ``embed_dim``
    channels; the optional norm is applied to the resulting tokens.

    Args:
        img_size: Volume size (stored only; not used to validate inputs).
        patch_size: Patch size, an int or a (kD, kH, kW) tuple (passed to Conv3d).
        in_chans: Number of input channels.
        embed_dim: Number of output channels per patch token.
        norm_layer: Optional normalization layer class applied to the tokens.
        flatten_embedding: If False, tokens are returned as (B, D', H', W', E).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int, int]] = 224,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batched volume ``(B, C, D, H, W)``.

        Raises:
            ValueError: if ``x`` is not 5-D. Conv3d would silently accept an
                unbatched 4-D input, but the token layout below would be wrong.
        """
        if x.dim() != 5:
            raise ValueError(f"PatchEmbed3D expects a (B, C, D, H, W) input, got shape {tuple(x.shape)}")
        x = self.proj(x)  # (B, E, D', H', W')
        grid = tuple(x.shape[2:])  # (D', H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, D'*H'*W', E)
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, *grid, self.embed_dim)  # (B, D', H', W', E)
        return x

    def reset_parameters(self):
        """Kaiming-uniform (linear gain) weight; bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        # weight: [out_c, in_c/groups, kd, kh, kw]
        nn.init.kaiming_uniform_(self.proj.weight, a=0.0, mode="fan_in", nonlinearity="linear")
        if self.proj.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
            bound = 1.0 / math.sqrt(fan_in)
            nn.init.uniform_(self.proj.bias, -bound, bound)
def _to_ntuple(n):
def parse(x):
if isinstance(x, (tuple, list)):
assert len(x) == n
return tuple(int(v) for v in x)
return tuple([int(x)] * n)
return parse
_to_2 = _to_ntuple(2)
_to_3 = _to_ntuple(3)
class PatchEmbedND(nn.Module):
def __init__(
self,
dim: int = 2,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer=None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
assert dim in (2, 3)
self.dim = dim
self.img_size = img_size
self.base_patch_size = _to_2(patch_size) if dim == 2 else _to_3(patch_size)
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
if dim == 2:
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=self.base_patch_size,
stride=self.base_patch_size)
else:
self.proj = nn.Conv3d(in_chans, embed_dim,
kernel_size=self.base_patch_size,
stride=self.base_patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self._pinv_cache = {}
self._runtime_patch_size: Optional[Tuple[int, ...]] = None
def set_patch_size(self, new_patch_size: Union[int, Tuple[int, ...]] = 16) -> None:
"""Set runtime patch size (no parameter rebuild)."""
self._runtime_patch_size = (
_to_2(new_patch_size) if self.dim == 2 else _to_3(new_patch_size)
)
def forward(self, x: Tensor) -> Tensor:
psize = self._runtime_patch_size or self.base_patch_size
if psize == self.base_patch_size:
x = self.proj(x)
else:
# Resample to a temporary weight (Tensor or DTensor) and use functional conv
W = self._resample_conv_weight(self.proj.weight, psize)
b = self.proj.bias
if self.dim == 2:
x = F.conv2d(x, W, b, stride=psize, padding=0)
else:
x = F.conv3d(x, W, b, stride=psize, padding=0)
if self.dim == 2:
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(B, H, W, self.embed_dim) # B H W C
else:
B, C, D, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # B DHW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(B, D, H, W, self.embed_dim) # B D H W C
return x
# ---- resampling that is Tensor/DTensor/FSDP-friendly ----
def _resample_conv_weight(self, weight: torch.Tensor, target_size: Tuple[int, ...]) -> torch.Tensor:
old_spatial = tuple(weight.shape[2:])
if old_spatial == tuple(target_size):
return weight
# Build or fetch pseudoinverse (∏old, ∏new) on same device
pinv = self._get_or_build_pinv(old_spatial, target_size, weight.device, torch.float32)
# If weight is DTensor, replicate pinv on same mesh so mm is DTensor x DTensor
if isinstance(weight, DTensor):
pinv = distribute_tensor(pinv, device_mesh=weight.device_mesh, placements=[Replicate()])
c_out, c_in = weight.shape[:2]
old_total = int(np.prod(old_spatial))
w = weight.to(torch.float32).reshape(c_out, c_in, old_total)
w = w @ pinv # -> (c_out, c_in, ∏new)
w = w.reshape(c_out, c_in, *target_size).to(weight.dtype)
return w
def _get_or_build_pinv(
self,
old_size: Tuple[int, ...],
new_size: Tuple[int, ...],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
key = (self.dim, old_size, new_size, device.type)
pinv = self._pinv_cache.get(key)
if pinv is not None:
return pinv.to(device=device, dtype=dtype)
old_total = int(np.prod(old_size))
new_total = int(np.prod(new_size))
eye = torch.eye(old_total, device=device, dtype=dtype)
basis = eye.reshape(old_total, 1, *old_size)
if self.dim == 2:
out = F.interpolate(basis, size=new_size, mode="bicubic", antialias=True, align_corners=False)
R = out.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total)
else:
out = F.interpolate(basis, size=new_size, mode="trilinear", align_corners=False)
R = out.squeeze(1).permute(1, 2, 3, 0).reshape(new_total, old_total)
pinv = torch.linalg.pinv(R).to(dtype) # (∏old, ∏new)
self._pinv_cache[key] = pinv.detach()
return pinv
def reset_parameters(self):
nn.init.kaiming_uniform_(self.proj.weight, a=0.0, mode="fan_in", nonlinearity="linear")
if self.proj.bias is not None:
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
bound = 1.0 / math.sqrt(fan_in)
nn.init.uniform_(self.proj.bias, -bound, bound)
class Flexi_CT_Core(nn.Module):
    """DINOv3-style ViT backbone that accepts either 2D images or 3D volumes.

    A 5-D input ``(B, C, D, H, W)`` is embedded with ``patch_embed_3D`` and uses
    3D RoPE. Any other input (expected ``(B, C, H, W)``) is embedded with
    ``patch_embed_2D`` and uses 2D RoPE. Both paths share the CLS/storage
    tokens, the transformer blocks and the output norms.

    After ``prepare_tokens_with_masks`` the token layout is
    ``[CLS, storage_0 .. storage_{n-1}, patch tokens...]``.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 1,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: float | None = None,
        pos_embed_rope_max_period: float | None = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: float | None = None,
        pos_embed_rope_jitter_coords: float | None = None,
        pos_embed_rope_rescale_coords: float | None = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 16,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: float | None = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any | None = None,
        **ignored_kwargs,
    ):
        """Build all submodules.

        ``cls_token``, ``storage_tokens`` and ``mask_token`` are allocated with
        ``torch.empty`` and only receive values in ``init_weights``.
        ``norm_layer`` / ``ffn_layer`` are keys of ``norm_layer_dict`` /
        ``ffn_layer_dict``, and ``pos_embed_rope_dtype`` is a key of
        ``dtype_dict``. Unknown keyword arguments are logged and ignored.
        """
        super().__init__()
        if len(ignored_kwargs) > 0:
            logger.warning(f"Ignored kwargs: {ignored_kwargs}")
        del ignored_kwargs

        norm_layer_cls = norm_layer_dict[norm_layer]

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size

        self.patch_embed_2D = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )
        self.patch_embed_3D = PatchEmbed3D(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )

        # Divisibility requirement of the 3D RoPE (see assertion message).
        assert embed_dim % (6 * num_heads) == 0, \
            f"embed_dim ({embed_dim}) must be divisible by 6*num_heads ({6*num_heads}) for 3D RoPE"

        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))

        logger.info(f"using base={pos_embed_rope_base} for rope new")
        logger.info(f"using min_period={pos_embed_rope_min_period} for rope new")
        logger.info(f"using max_period={pos_embed_rope_max_period} for rope new")
        logger.info(f"using normalize_coords={pos_embed_rope_normalize_coords} for rope new")
        logger.info(f"using shift_coords={pos_embed_rope_shift_coords} for rope new")
        logger.info(f"using rescale_coords={pos_embed_rope_rescale_coords} for rope new")
        logger.info(f"using jitter_coords={pos_embed_rope_jitter_coords} for rope new")
        logger.info(f"using dtype={pos_embed_rope_dtype} for rope new")
        self.rope_embed_2D = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )
        self.rope_embed_3D = RopePositionEmbedding3D(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )

        logger.info(f"using {ffn_layer} layer as FFN")
        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=drop_path_rate,
                norm_layer=norm_layer_cls,
                act_layer=nn.GELU,
                ffn_layer=ffn_layer_cls,
                init_values=layerscale_init,
                mask_k_bias=mask_k_bias,
                device=device,
            )
            for i in range(depth)
        ]
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

        # This norm is applied to everything, or when untying, to patch and mask tokens.
        self.norm = norm_layer_cls(embed_dim)
        self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
        if untie_cls_and_patch_norms:
            # When untying, this norm is applied to CLS tokens and registers.
            self.cls_norm = norm_layer_cls(embed_dim)
        else:
            self.cls_norm = None
        self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
        if untie_global_and_local_cls_norm:
            # When untying, this norm is applied to local CLS tokens and registers.
            # This norm is never used during eval.
            self.local_cls_norm = norm_layer_cls(embed_dim)
        else:
            self.local_cls_norm = None

        self.head = nn.Identity()
        self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))

    def init_weights(self):
        """Initialize RoPE, CLS/storage tokens (N(0, 0.02)), mask token (zeros), then all submodules."""
        self.rope_embed_2D._init_weights()
        self.rope_embed_3D._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        nn.init.zeros_(self.mask_token)
        named_apply(init_weights_vit, self)

    def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, ...]]:
        """Embed patches, apply the mask and prepend the CLS and storage tokens.

        Args:
            x: 5-D volume batch ``(B, C, D, H, W)``, otherwise a 2D image batch.
            masks: Optional boolean tensor; where True, the patch token is
                replaced by ``mask_token``. Presumably shaped ``(B, N)`` with N
                the number of patch tokens — TODO confirm against callers.

        Returns:
            ``(tokens, grid)``: tokens is ``(B, 1 + n_storage_tokens + N, embed_dim)``
            and grid is the patch-token grid ``(D, H, W)`` or ``(H, W)``.
        """
        if x.dim() == 5:
            x = self.patch_embed_3D(x)
            B, D, H, W, _ = x.shape
            x = x.flatten(1, 3)
        else:
            x = self.patch_embed_2D(x)
            B, H, W, _ = x.shape
            D = None
            x = x.flatten(1, 2)

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            cls_token = self.cls_token
        else:
            # Zero-valued use of mask_token keeps it in the autograd graph
            # (presumably so distributed wrappers don't flag it as unused).
            cls_token = self.cls_token + 0 * self.mask_token

        if self.n_storage_tokens > 0:
            storage_tokens = self.storage_tokens
        else:
            storage_tokens = torch.empty(
                1,
                0,
                cls_token.shape[-1],
                dtype=cls_token.dtype,
                device=cls_token.device,
            )

        x = torch.cat(
            [
                cls_token.expand(B, -1, -1),
                storage_tokens.expand(B, -1, -1),
                x,
            ],
            dim=1,
        )
        return x, ((D, H, W) if D is not None else (H, W))

    def _rope_for_grid(self, grid: Tuple[int, ...]):
        """Return the RoPE sin/cos for a ``(D, H, W)`` (3D) or ``(H, W)`` (2D) token grid."""
        if len(grid) == 3:
            D, H, W = grid
            return self.rope_embed_3D(D=D, H=H, W=W)
        H, W = grid
        return self.rope_embed_2D(H=H, W=W)

    def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
        """Run several inputs (e.g. global and local crops) through the blocks together.

        The RoPE dimensionality is chosen per entry, so a list may mix 2D and 3D
        inputs.

        Returns:
            One dict per input with keys ``x_norm_clstoken`` ``(B, E)``,
            ``x_storage_tokens``, ``x_norm_patchtokens``, ``x_prenorm`` (the
            un-normed block output) and ``masks`` (passed through).
        """
        tokens_list = []
        grids = []
        for t_x, t_masks in zip(x_list, masks_list):
            tokens, grid = self.prepare_tokens_with_masks(t_x, t_masks)
            tokens_list.append(tokens)
            grids.append(grid)
        rope_sincos = [self._rope_for_grid(grid) for grid in grids]

        x = tokens_list
        for blk in self.blocks:
            x = blk(x, rope_sincos)

        n_prefix = self.n_storage_tokens + 1  # CLS + storage tokens
        output = []
        for idx, (tokens, masks) in enumerate(zip(x, masks_list)):
            if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
                if self.untie_global_and_local_cls_norm and self.training and idx == 1:
                    # Assume second entry of list corresponds to local crops.
                    # We only ever apply this during training.
                    x_norm_cls_reg = self.local_cls_norm(tokens[:, :n_prefix])
                elif self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(tokens[:, :n_prefix])
                else:
                    x_norm_cls_reg = self.norm(tokens[:, :n_prefix])
                x_norm_patch = self.norm(tokens[:, n_prefix:])
            else:
                x_norm = self.norm(tokens)
                x_norm_cls_reg = x_norm[:, :n_prefix]
                x_norm_patch = x_norm[:, n_prefix:]
            output.append(
                {
                    "x_norm_clstoken": x_norm_cls_reg[:, 0],
                    "x_storage_tokens": x_norm_cls_reg[:, 1:],
                    "x_norm_patchtokens": x_norm_patch,
                    "x_prenorm": tokens,
                    "masks": masks,
                }
            )
        return output

    def forward_features(
        self,
        x: Tensor | List[Tensor],
        masks: Optional[Tensor | List[Optional[Tensor]]] = None,
    ) -> Dict[str, Tensor] | List[Dict[str, Tensor]]:
        """Single tensor -> one output dict; list of tensors -> list of output dicts.

        For list input, ``masks`` is a parallel list. ``None`` means no mask for
        any entry; it used to crash on ``zip(x, None)``.
        """
        if isinstance(x, torch.Tensor):
            return self.forward_features_list([x], [masks])[0]
        masks_list = [None] * len(x) if masks is None else masks
        return self.forward_features_list(x, masks_list)

    def _collect_block_outputs(
        self, x: Tensor, n: Union[int, Sequence] = 1
    ) -> Tuple[List[Tensor], Tuple[int, ...]]:
        """Run every block on ``x``.

        Returns the raw outputs of the selected blocks and the patch-token grid.
        ``n`` is either the number of last blocks to keep or an explicit
        sequence of block indices.
        """
        x, grid = self.prepare_tokens_with_masks(x)
        rope_sincos = self._rope_for_grid(grid)
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x, rope_sincos)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output, grid

    def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int = 1) -> List[Tensor]:
        """Return the raw (un-normed) token outputs of the selected blocks."""
        return self._collect_block_outputs(x, n)[0]

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        *,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        return_extra_tokens: bool = False,
        norm: bool = True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]:
        """Return patch tokens of intermediate blocks, optionally with CLS/storage tokens.

        With ``reshape=True`` the patch tokens become ``(B, E, H', W')`` or
        ``(B, E, D', H', W')``. The grid comes from the patch embedding itself,
        so it stays correct when a runtime patch size differs from
        ``self.patch_size``.
        """
        outputs, grid = self._collect_block_outputs(x, n)
        if norm:
            outputs_normed = []
            for out in outputs:
                if self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1])
                    x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :])
                    outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1))
                else:
                    outputs_normed.append(self.norm(out))
            outputs = outputs_normed
        class_tokens = [out[:, 0] for out in outputs]
        extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs]
        outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs]
        if reshape:
            B = x.shape[0]
            # (B, N, E) -> (B, *grid, E) -> (B, E, *grid)
            outputs = [out.reshape(B, *grid, -1).movedim(-1, 1).contiguous() for out in outputs]
        if return_class_token and return_extra_tokens:
            return tuple(zip(outputs, class_tokens, extra_tokens))
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        if return_extra_tokens:
            return tuple(zip(outputs, extra_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training: bool = False, **kwargs) -> List[Dict[str, Tensor]] | Tensor:
        """Return ``forward_features`` output when training, else ``head`` applied to the CLS token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        return self.head(ret["x_norm_clstoken"])

    @torch.no_grad()
    def inflate_patch_embed3d_from_2d(
        self,
        mode: str = "avg",  # "avg" or "center"
    ) -> None:
        """
        Initialize ``patch_embed_3D.proj`` in place by inflating ``patch_embed_2D.proj``.

        ``patch_embed_2D.proj`` must be a Conv2d ``[C_out, C_in, kH, kW]`` and
        ``patch_embed_3D.proj`` a Conv3d ``[C_out, C_in, kD, kH, kW]`` with the
        same channel counts and in-plane kernel size.

        Args:
            mode:
                - "avg": copy the 2D kernel into every depth slice divided by kD
                  (I3D-style), so an input constant along depth reproduces the
                  2D response.
                - "center": copy the 2D kernel into the center depth slice only;
                  all other slices are zero.

        The bias is copied unchanged when both layers have one.

        Raises:
            AssertionError: on a layer-type, channel or kernel-size mismatch.
            ValueError: for an unknown ``mode``.
        """
        assert isinstance(self.patch_embed_2D.proj, nn.Conv2d) and isinstance(self.patch_embed_3D.proj, nn.Conv3d), \
            "pe2d.proj must be Conv2d and pe3d.proj must be Conv3d"
        w2 = self.patch_embed_2D.proj.weight.data  # [C_out, C_in, kH2, kW2]
        b2 = self.patch_embed_2D.proj.bias.data if self.patch_embed_2D.proj.bias is not None else None
        w3 = self.patch_embed_3D.proj.weight.data  # [C_out, C_in, kD3, kH3, kW3]
        b3 = self.patch_embed_3D.proj.bias.data if self.patch_embed_3D.proj.bias is not None else None
        C_out2, C_in2, kH2, kW2 = w2.shape
        C_out3, C_in3, kD3, kH3, kW3 = w3.shape
        assert C_out2 == C_out3, f"out_channels mismatch: {C_out2} vs {C_out3}"
        assert C_in2 == C_in3, f"in_channels mismatch: {C_in2} vs {C_in3}"
        assert kH2 == kH3 and kW2 == kW3, \
            f"spatial kernel mismatch: (kH,kW)=({kH2},{kW2}) vs ({kH3},{kW3})"
        w3.zero_()
        if mode == "avg":
            # Summing the kD slices reproduces the 2D kernel.
            for t in range(kD3):
                w3[:, :, t, :, :].copy_(w2 / kD3)
        elif mode == "center":
            center = kD3 // 2
            w3[:, :, center, :, :].copy_(w2)
        else:
            raise ValueError(f"Unknown mode='{mode}', expected 'avg' or 'center'.")
        if b2 is not None and b3 is not None:
            b3.copy_(b2)
class Flexi_CT_Backbone(Flexi_CT_Core):
    """``Flexi_CT_Core`` whose 2D and 3D patch embeddings are ``PatchEmbedND``.

    Every constructor argument is forwarded to ``Flexi_CT_Core.__init__``. The
    two patch-embedding modules built there are then replaced by
    ``PatchEmbedND(dim=2)`` and ``PatchEmbedND(dim=3)``, which allow changing
    the patch size at runtime through ``set_patch_size``.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 1,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: float | None = None,
        pos_embed_rope_max_period: float | None = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: float | None = None,
        pos_embed_rope_jitter_coords: float | None = None,
        pos_embed_rope_rescale_coords: float | None = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 16,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: float | None = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any | None = None,
        **ignored_kwargs,
    ):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            pos_embed_rope_base=pos_embed_rope_base,
            pos_embed_rope_min_period=pos_embed_rope_min_period,
            pos_embed_rope_max_period=pos_embed_rope_max_period,
            pos_embed_rope_normalize_coords=pos_embed_rope_normalize_coords,
            pos_embed_rope_shift_coords=pos_embed_rope_shift_coords,
            pos_embed_rope_jitter_coords=pos_embed_rope_jitter_coords,
            pos_embed_rope_rescale_coords=pos_embed_rope_rescale_coords,
            pos_embed_rope_dtype=pos_embed_rope_dtype,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            qkv_bias=qkv_bias,
            drop_path_rate=drop_path_rate,
            layerscale_init=layerscale_init,
            norm_layer=norm_layer,
            ffn_layer=ffn_layer,
            ffn_bias=ffn_bias,
            proj_bias=proj_bias,
            n_storage_tokens=n_storage_tokens,
            mask_k_bias=mask_k_bias,
            untie_cls_and_patch_norms=untie_cls_and_patch_norms,
            untie_global_and_local_cls_norm=untie_global_and_local_cls_norm,
            device=device,
            **ignored_kwargs,
        )
        # Settings common to both flexible embeddings; only `dim` differs.
        shared_embed_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )
        # Built 2D first, then 3D (same order as before, so same RNG consumption).
        self.patch_embed_2D = PatchEmbedND(dim=2, **shared_embed_kwargs)
        self.patch_embed_3D = PatchEmbedND(dim=3, **shared_embed_kwargs)
class Flexi_CT_VLM_Module(nn.Module):
def __init__(
self,
*,
vision_model: nn.Module,
text_model: nn.Module,
embed_dim = 1024,
device: Any | None = None,
**ignored_kwargs,
):
super().__init__()
# Call parent class constructor with all required parameters
self.vision_model = vision_model
self.text_model = text_model
self.logit_scale = nn.Parameter(torch.empty(1))
self.vlm_embed_dim = embed_dim
self.vlm_vision_projection = nn.Linear(2*vision_model.embed_dim, self.vlm_embed_dim, bias=False)
self.device = device
def forward(self, images: torch.Tensor, text) -> torch.Tensor:
vision_features = self.vision_model(images, is_training = True)
cls_token = vision_features["x_norm_clstoken"] # [B, D]
patch_tokens = vision_features["x_norm_patchtokens"] # [B, P, D]
# Mean pool patch tokens (like DINOTxt)
mean_patch_token = torch.mean(patch_tokens, dim=1) # [B, D]
# Concatenate CLS + mean(patch) along channel dimension (like DINOTxt)
image_features = torch.cat([cls_token, mean_patch_token], dim=-1) # [B, 2*D]
# Project vision features to VLM embedding space
image_features = self.vlm_vision_projection(image_features) # [B, vlm_embed_dim]
# Normalize image features
image_features = F.normalize(image_features, dim=-1)
text_features = self.text_model(**text)
# Normalize text features
text_features = F.normalize(text_features, dim=-1)
return self.logit_scale.exp(), image_features, text_features
def flexi_ct_backbone_base(patch_size=8, **kwargs):
    """Build the base-size ``Flexi_CT_Backbone``.

    Fixed settings: embed_dim=864, depth=16, num_heads=12, ffn_ratio=4. Any
    other keyword argument is forwarded to ``Flexi_CT_Backbone`` unchanged.
    """
    return Flexi_CT_Backbone(
        patch_size=patch_size,
        embed_dim=864,
        depth=16,
        num_heads=12,
        ffn_ratio=4,
        **kwargs,
    )