Image Feature Extraction
Transformers
Safetensors
flexict
feature-extraction
medical-imaging
ct
vision
custom_code
Instructions to use ricklisz123/FlexiCT-2D with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ricklisz123/FlexiCT-2D with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-feature-extraction", model="ricklisz123/FlexiCT-2D", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ricklisz123/FlexiCT-2D", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This software may be used and distributed in accordance with | |
| # the terms of the DINOv3 License Agreement. | |
| import logging | |
| from functools import partial | |
| from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union | |
| import torch.nn.functional as F | |
| import torch | |
| import torch.nn.init | |
| from torch import Tensor, nn | |
| from torch.nn.init import trunc_normal_ | |
| from .layers import LayerScale, Mlp, PatchEmbed, RMSNorm, RopePositionEmbedding, RopePositionEmbedding3D, SelfAttentionBlock, SwiGLUFFN | |
| from .utils import named_apply | |
| import math | |
| import numpy as np | |
logger = logging.getLogger("dinov3")

# Feed-forward implementations selectable by name; the "swigluNN" entries
# bind ``align_to=NN`` on SwiGLUFFN.
ffn_layer_dict = dict(
    mlp=Mlp,
    swiglu=SwiGLUFFN,
    swiglu32=partial(SwiGLUFFN, align_to=32),
    swiglu64=partial(SwiGLUFFN, align_to=64),
    swiglu128=partial(SwiGLUFFN, align_to=128),
)

# Normalization layer factories selectable by name.
norm_layer_dict = dict(
    layernorm=partial(nn.LayerNorm, eps=1e-6),
    layernormbf16=partial(nn.LayerNorm, eps=1e-5),
    rmsnorm=RMSNorm,
)

# Short names for torch dtypes (looked up for the RoPE dtype option).
dtype_dict = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
def init_weights_vit(module: nn.Module, name: str = ""):
    """Initialise a single sub-module in place.

    ``nn.Linear`` layers get truncated-normal weights (std 0.02) and zeroed
    biases; if the layer carries a non-None ``bias_mask`` it is set to ones
    with the middle third (``o // 3`` to ``2 * o // 3``) zeroed. Norm,
    layer-scale and patch-embedding modules re-run their own
    ``reset_parameters``.

    Args:
        module: Module to initialise.
        name: Not used by this function.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        mask = getattr(module, "bias_mask", None)
        if mask is not None:
            out_features = module.out_features
            mask.fill_(1)
            mask[out_features // 3 : 2 * out_features // 3].fill_(0)

    # Each matching type triggers its own reset, checked in the same order as before.
    self_resetting = (nn.LayerNorm, LayerScale, PatchEmbed, PatchEmbedND, PatchEmbed3D, RMSNorm)
    for kind in self_resetting:
        if isinstance(module, kind):
            module.reset_parameters()
| from torch.distributed._tensor import DTensor, distribute_tensor, Replicate | |
| def _replicate_like(x: torch.Tensor, like: DTensor) -> DTensor: | |
| # Make 'x' a replicated DTensor on the same mesh as 'like' | |
| if isinstance(x, DTensor): | |
| return x | |
| return distribute_tensor(x, device_mesh=like.device_mesh, placements=[Replicate()]) | |
class PatchEmbed3D(nn.Module):
    """
    3D volume to patch embedding: (B, C, D, H, W) -> (B, N, embed_dim).

    Args:
        img_size: Volume size. Stored on the module; not used by ``forward``.
        patch_size: Patch size, used as both kernel size and stride of the projection.
        in_chans: Number of input channels.
        embed_dim: Number of projection output channels.
        norm_layer: Optional callable building a normalization layer from ``embed_dim``.
        flatten_embedding: If True return (B, N, embed_dim); otherwise keep the
            patch grid and return (B, D', H', W', embed_dim).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int, int]] = 224,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride are both the patch size.
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of volumes.

        Args:
            x: Tensor of shape (B, C, D, H, W).

        Returns:
            (B, N, embed_dim) when ``flatten_embedding`` is True, otherwise
            (B, D', H', W', embed_dim) where D', H', W' is the patch grid.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"PatchEmbed3D expects a 5D input (B, C, D, H, W), got shape {tuple(x.shape)}"
            )
        x = self.proj(x)  # B C D' H' W'
        grid = tuple(x.shape[2:])
        x = x.flatten(2).transpose(1, 2)  # B D'H'W' C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, *grid, self.embed_dim)  # B D' H' W' C
        return x

    def reset_parameters(self):
        """Re-initialise the projection: Kaiming-uniform weight, uniform bias in +-1/sqrt(fan_in)."""
        # weight: [out_c, in_c/groups, kd, kh, kw]
        nn.init.kaiming_uniform_(self.proj.weight, a=0.0, mode="fan_in", nonlinearity="linear")
        if self.proj.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
            bound = 1.0 / math.sqrt(fan_in)
            nn.init.uniform_(self.proj.bias, -bound, bound)
| def _to_ntuple(n): | |
| def parse(x): | |
| if isinstance(x, (tuple, list)): | |
| assert len(x) == n | |
| return tuple(int(v) for v in x) | |
| return tuple([int(x)] * n) | |
| return parse | |
| _to_2 = _to_ntuple(2) | |
| _to_3 = _to_ntuple(3) | |
class PatchEmbedND(nn.Module):
    """Patch embedding for 2D images or 3D volumes with a patch size that can be changed at runtime.

    The projection is a Conv2d/Conv3d whose kernel and stride equal the base
    patch size. When a different runtime patch size is set, the kernel is
    resampled on the fly through a cached pseudo-inverse of the interpolation
    matrix and applied with the functional convolution; the stored parameters
    are never rebuilt.
    """

    def __init__(
        self,
        dim: int = 2,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        assert dim in (2, 3)
        self.dim = dim
        self.img_size = img_size
        self.base_patch_size = _to_2(patch_size) if dim == 2 else _to_3(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        conv_cls = nn.Conv2d if dim == 2 else nn.Conv3d
        self.proj = conv_cls(
            in_chans,
            embed_dim,
            kernel_size=self.base_patch_size,
            stride=self.base_patch_size,
        )
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

        # Pseudo-inverses keyed by (dim, old size, new size, device type).
        self._pinv_cache = {}
        self._runtime_patch_size: Optional[Tuple[int, ...]] = None

    def set_patch_size(self, new_patch_size: Union[int, Tuple[int, ...]] = 16) -> None:
        """Set runtime patch size (no parameter rebuild)."""
        to_tuple = _to_2 if self.dim == 2 else _to_3
        self._runtime_patch_size = to_tuple(new_patch_size)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` into patch tokens.

        Returns (B, N, embed_dim) when ``flatten_embedding`` is True, otherwise
        the tokens reshaped to (B, *grid, embed_dim).
        """
        patch = self._runtime_patch_size or self.base_patch_size
        if patch == self.base_patch_size:
            x = self.proj(x)
        else:
            # Temporary resampled kernel (Tensor or DTensor) fed to the functional conv.
            kernel = self._resample_conv_weight(self.proj.weight, patch)
            conv = F.conv2d if self.dim == 2 else F.conv3d
            x = conv(x, kernel, self.proj.bias, stride=patch, padding=0)

        if self.dim == 2:
            batch, _, rows, cols = x.shape
            grid = (rows, cols)
        else:
            batch, _, depth, rows, cols = x.shape
            grid = (depth, rows, cols)

        x = self.norm(x.flatten(2).transpose(1, 2))  # B N C
        if not self.flatten_embedding:
            x = x.reshape(batch, *grid, self.embed_dim)  # B *grid C
        return x

    # ---- resampling that is Tensor/DTensor/FSDP-friendly ----
    def _resample_conv_weight(self, weight: torch.Tensor, target_size: Tuple[int, ...]) -> torch.Tensor:
        """Return ``weight`` with its spatial dims resampled to ``target_size``.

        The same object is returned when the size already matches. The math is
        done in float32 and the result is cast back to ``weight.dtype``.
        """
        source_size = tuple(weight.shape[2:])
        if source_size == tuple(target_size):
            return weight

        # (prod(old), prod(new)) pseudo-inverse on the weight's device.
        pinv = self._get_or_build_pinv(source_size, target_size, weight.device, torch.float32)
        if isinstance(weight, DTensor):
            # Replicate on the weight's mesh so the matmul is DTensor x DTensor.
            pinv = distribute_tensor(pinv, device_mesh=weight.device_mesh, placements=[Replicate()])

        c_out, c_in = weight.shape[:2]
        flat = weight.to(torch.float32).reshape(c_out, c_in, math.prod(source_size))
        resampled = flat @ pinv  # (c_out, c_in, prod(new))
        return resampled.reshape(c_out, c_in, *target_size).to(weight.dtype)

    def _get_or_build_pinv(
        self,
        old_size: Tuple[int, ...],
        new_size: Tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Fetch or compute the pseudo-inverse of the ``old_size`` -> ``new_size`` resize matrix.

        The resize matrix is obtained by interpolating every one-hot basis
        kernel (bicubic with antialiasing in 2D, trilinear in 3D).
        """
        key = (self.dim, old_size, new_size, device.type)
        cached = self._pinv_cache.get(key)
        if cached is not None:
            return cached.to(device=device, dtype=dtype)

        n_old = math.prod(old_size)
        n_new = math.prod(new_size)
        basis = torch.eye(n_old, device=device, dtype=dtype).reshape(n_old, 1, *old_size)

        if self.dim == 2:
            resized = F.interpolate(basis, size=new_size, mode="bicubic", antialias=True, align_corners=False)
        else:
            resized = F.interpolate(basis, size=new_size, mode="trilinear", align_corners=False)
        # Move the basis index last, then flatten the spatial dims: (prod(new), prod(old)).
        resize_matrix = resized.squeeze(1).movedim(0, -1).reshape(n_new, n_old)

        pinv = torch.linalg.pinv(resize_matrix).to(dtype)  # (prod(old), prod(new))
        self._pinv_cache[key] = pinv.detach()
        return pinv

    def reset_parameters(self):
        """Re-initialise the projection: Kaiming-uniform weight, uniform bias in +-1/sqrt(fan_in)."""
        weight = self.proj.weight
        nn.init.kaiming_uniform_(weight, a=0.0, mode="fan_in", nonlinearity="linear")
        bias = self.proj.bias
        if bias is None:
            return
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight)
        limit = 1.0 / math.sqrt(fan_in)
        nn.init.uniform_(bias, -limit, limit)
| class Flexi_CT_Core(nn.Module): | |
| def __init__( | |
| self, | |
| *, | |
| img_size: int = 224, | |
| patch_size: int = 16, | |
| in_chans: int = 1, | |
| pos_embed_rope_base: float = 100.0, | |
| pos_embed_rope_min_period: float | None = None, | |
| pos_embed_rope_max_period: float | None = None, | |
| pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", | |
| pos_embed_rope_shift_coords: float | None = None, | |
| pos_embed_rope_jitter_coords: float | None = None, | |
| pos_embed_rope_rescale_coords: float | None = None, | |
| pos_embed_rope_dtype: str = "bf16", | |
| embed_dim: int = 768, | |
| depth: int = 12, | |
| num_heads: int = 16, | |
| ffn_ratio: float = 4.0, | |
| qkv_bias: bool = True, | |
| drop_path_rate: float = 0.0, | |
| layerscale_init: float | None = None, | |
| norm_layer: str = "layernorm", | |
| ffn_layer: str = "mlp", | |
| ffn_bias: bool = True, | |
| proj_bias: bool = True, | |
| n_storage_tokens: int = 0, | |
| mask_k_bias: bool = False, | |
| untie_cls_and_patch_norms: bool = False, | |
| untie_global_and_local_cls_norm: bool = False, | |
| device: Any | None = None, | |
| **ignored_kwargs, | |
| ): | |
| super().__init__() | |
| if len(ignored_kwargs) > 0: | |
| logger.warning(f"Ignored kwargs: {ignored_kwargs}") | |
| del ignored_kwargs | |
| norm_layer_cls = norm_layer_dict[norm_layer] | |
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |
| self.n_blocks = depth | |
| self.num_heads = num_heads | |
| self.patch_size = patch_size | |
| self.patch_embed_2D = PatchEmbed( | |
| img_size=img_size, | |
| patch_size=patch_size, | |
| in_chans=in_chans, | |
| embed_dim=embed_dim, | |
| flatten_embedding=False, | |
| ) | |
| self.patch_embed_3D = PatchEmbed3D( | |
| img_size=img_size, | |
| patch_size=patch_size, | |
| in_chans=in_chans, | |
| embed_dim=embed_dim, | |
| flatten_embedding=False, | |
| ) | |
| # Optional, but very helpful: | |
| assert embed_dim % (6 * num_heads) == 0, \ | |
| f"embed_dim ({embed_dim}) must be divisible by 6*num_heads ({6*num_heads}) for 3D RoPE" | |
| self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device)) | |
| self.n_storage_tokens = n_storage_tokens | |
| if self.n_storage_tokens > 0: | |
| self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device)) | |
| logger.info(f"using base={pos_embed_rope_base} for rope new") | |
| logger.info(f"using min_period={pos_embed_rope_min_period} for rope new") | |
| logger.info(f"using max_period={pos_embed_rope_max_period} for rope new") | |
| logger.info(f"using normalize_coords={pos_embed_rope_normalize_coords} for rope new") | |
| logger.info(f"using shift_coords={pos_embed_rope_shift_coords} for rope new") | |
| logger.info(f"using rescale_coords={pos_embed_rope_rescale_coords} for rope new") | |
| logger.info(f"using jitter_coords={pos_embed_rope_jitter_coords} for rope new") | |
| logger.info(f"using dtype={pos_embed_rope_dtype} for rope new") | |
| self.rope_embed_2D = RopePositionEmbedding( | |
| embed_dim=embed_dim, | |
| num_heads=num_heads, | |
| base=pos_embed_rope_base, | |
| min_period=pos_embed_rope_min_period, | |
| max_period=pos_embed_rope_max_period, | |
| normalize_coords=pos_embed_rope_normalize_coords, | |
| shift_coords=pos_embed_rope_shift_coords, | |
| jitter_coords=pos_embed_rope_jitter_coords, | |
| rescale_coords=pos_embed_rope_rescale_coords, | |
| dtype=dtype_dict[pos_embed_rope_dtype], | |
| device=device, | |
| ) | |
| self.rope_embed_3D = RopePositionEmbedding3D( | |
| embed_dim=embed_dim, | |
| num_heads=num_heads, | |
| base=pos_embed_rope_base, | |
| min_period=pos_embed_rope_min_period, | |
| max_period=pos_embed_rope_max_period, | |
| normalize_coords=pos_embed_rope_normalize_coords, | |
| shift_coords=pos_embed_rope_shift_coords, | |
| jitter_coords=pos_embed_rope_jitter_coords, | |
| rescale_coords=pos_embed_rope_rescale_coords, | |
| dtype=dtype_dict[pos_embed_rope_dtype], | |
| device=device, | |
| ) | |
| logger.info(f"using {ffn_layer} layer as FFN") | |
| ffn_layer_cls = ffn_layer_dict[ffn_layer] | |
| ffn_ratio_sequence = [ffn_ratio] * depth | |
| blocks_list = [ | |
| SelfAttentionBlock( | |
| dim=embed_dim, | |
| num_heads=num_heads, | |
| ffn_ratio=ffn_ratio_sequence[i], | |
| qkv_bias=qkv_bias, | |
| proj_bias=proj_bias, | |
| ffn_bias=ffn_bias, | |
| drop_path=drop_path_rate, | |
| norm_layer=norm_layer_cls, | |
| act_layer=nn.GELU, | |
| ffn_layer=ffn_layer_cls, | |
| init_values=layerscale_init, | |
| mask_k_bias=mask_k_bias, | |
| device=device, | |
| ) | |
| for i in range(depth) | |
| ] | |
| self.chunked_blocks = False | |
| self.blocks = nn.ModuleList(blocks_list) | |
| # This norm is applied to everything, or when untying, to patch and mask tokens. | |
| self.norm = norm_layer_cls(embed_dim) | |
| self.untie_cls_and_patch_norms = untie_cls_and_patch_norms | |
| if untie_cls_and_patch_norms: | |
| # When untying, this norm is applied to CLS tokens and registers. | |
| self.cls_norm = norm_layer_cls(embed_dim) | |
| else: | |
| self.cls_norm = None | |
| self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm | |
| if untie_global_and_local_cls_norm: | |
| # When untying, this norm is applied to local CLS tokens and registers. | |
| # This norm is never used during eval. | |
| self.local_cls_norm = norm_layer_cls(embed_dim) | |
| else: | |
| self.local_cls_norm = None | |
| self.head = nn.Identity() | |
| self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device)) | |
| def init_weights(self): | |
| self.rope_embed_2D._init_weights() | |
| self.rope_embed_3D._init_weights() | |
| nn.init.normal_(self.cls_token, std=0.02) | |
| if self.n_storage_tokens > 0: | |
| nn.init.normal_(self.storage_tokens, std=0.02) | |
| nn.init.zeros_(self.mask_token) | |
| named_apply(init_weights_vit, self) | |
| def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int]]: | |
| if x.dim() == 5: | |
| x = self.patch_embed_3D(x) | |
| B, D, H, W, _ = x.shape | |
| x = x.flatten(1, 3) | |
| else: | |
| x = self.patch_embed_2D(x) | |
| B, H, W, _ = x.shape | |
| D = None | |
| x = x.flatten(1, 2) | |
| if masks is not None: | |
| x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) | |
| cls_token = self.cls_token | |
| else: | |
| cls_token = self.cls_token + 0 * self.mask_token | |
| if self.n_storage_tokens > 0: | |
| storage_tokens = self.storage_tokens | |
| else: | |
| storage_tokens = torch.empty( | |
| 1, | |
| 0, | |
| cls_token.shape[-1], | |
| dtype=cls_token.dtype, | |
| device=cls_token.device, | |
| ) | |
| x = torch.cat( | |
| [ | |
| cls_token.expand(B, -1, -1), | |
| storage_tokens.expand(B, -1, -1), | |
| x, | |
| ], | |
| dim=1, | |
| ) | |
| return x, ((D, H, W) if D is not None else (H, W)) | |
| def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]: | |
| x = [] | |
| rope = [] | |
| for t_x, t_masks in zip(x_list, masks_list): | |
| t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks) | |
| x.append(t2_x) | |
| rope.append(hw_tuple) | |
| if len(rope[0]) == 3: | |
| rope_sincos = [self.rope_embed_3D(D=D, H=H, W=W) for D, H, W in rope] | |
| else: | |
| rope_sincos = [self.rope_embed_2D(H=H, W=W) for H, W in rope] | |
| for _, blk in enumerate(self.blocks): | |
| x = blk(x, rope_sincos) | |
| n_storage_tokens = self.n_storage_tokens | |
| norm = self.norm | |
| cls_norm = self.cls_norm | |
| local_cls_norm = self.local_cls_norm | |
| untie_cls_and_patch_norms = self.untie_cls_and_patch_norms | |
| untie_global_and_local_cls_norm = self.untie_global_and_local_cls_norm | |
| training = self.training | |
| all_x = x | |
| output = [] | |
| for idx, (x, masks) in enumerate(zip(all_x, masks_list)): | |
| if untie_cls_and_patch_norms or untie_global_and_local_cls_norm: | |
| if untie_global_and_local_cls_norm and training and idx == 1: | |
| # Assume second entry of list corresponds to local crops. | |
| # We only ever apply this during training. | |
| x_norm_cls_reg = local_cls_norm(x[:, : n_storage_tokens + 1]) | |
| elif untie_cls_and_patch_norms: | |
| x_norm_cls_reg = cls_norm(x[:, : n_storage_tokens + 1]) | |
| else: | |
| x_norm_cls_reg = norm(x[:, : n_storage_tokens + 1]) | |
| x_norm_patch = norm(x[:, n_storage_tokens + 1 :]) | |
| else: | |
| x_norm = norm(x) | |
| x_norm_cls_reg = x_norm[:, : n_storage_tokens + 1] | |
| x_norm_patch = x_norm[:, n_storage_tokens + 1 :] | |
| output.append( | |
| { | |
| "x_norm_clstoken": x_norm_cls_reg[:, 0], | |
| "x_storage_tokens": x_norm_cls_reg[:, 1:], | |
| "x_norm_patchtokens": x_norm_patch, | |
| "x_prenorm": x, | |
| "masks": masks, | |
| } | |
| ) | |
| return output | |
| def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]: | |
| if isinstance(x, torch.Tensor): | |
| return self.forward_features_list([x], [masks])[0] | |
| else: | |
| return self.forward_features_list(x, masks) | |
| def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int = 1) -> List[Tensor]: | |
| x, hw_tuple = self.prepare_tokens_with_masks(x) | |
| if len(hw_tuple) == 3: | |
| D, H, W = hw_tuple | |
| rope_sincos = self.rope_embed_3D(D=D, H=H, W=W) | |
| else: | |
| H, W = hw_tuple | |
| rope_sincos = self.rope_embed_2D(H=H, W=W) | |
| # If n is an int, take the n last blocks. If it's a list, take them | |
| output, total_block_len = [], len(self.blocks) | |
| blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n | |
| for i, blk in enumerate(self.blocks): | |
| x = blk(x, rope_sincos) | |
| if i in blocks_to_take: | |
| output.append(x) | |
| assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" | |
| return output | |
| def get_intermediate_layers( | |
| self, | |
| x: torch.Tensor, | |
| *, | |
| n: Union[int, Sequence] = 1, # Layers or n last layers to take | |
| reshape: bool = False, | |
| return_class_token: bool = False, | |
| return_extra_tokens: bool = False, | |
| norm: bool = True, | |
| ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]: | |
| outputs = self._get_intermediate_layers_not_chunked(x, n) | |
| if norm: | |
| outputs_normed = [] | |
| for out in outputs: | |
| if self.untie_cls_and_patch_norms: | |
| x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1]) | |
| x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :]) | |
| outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1)) | |
| else: | |
| outputs_normed.append(self.norm(out)) | |
| outputs = outputs_normed | |
| class_tokens = [out[:, 0] for out in outputs] | |
| extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs] | |
| outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs] | |
| if reshape: | |
| if x.dim() == 5: | |
| B, _, d, h, w = x.shape | |
| outputs = [ | |
| out.reshape(B, d // self.patch_size, h // self.patch_size, w // self.patch_size, -1).permute(0, 4, 1, 2, 3).contiguous() | |
| for out in outputs | |
| ] | |
| else: | |
| B, _, h, w = x.shape | |
| outputs = [ | |
| out.reshape(B, h // self.patch_size, w // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() | |
| for out in outputs | |
| ] | |
| if not return_class_token and not return_extra_tokens: | |
| return tuple(outputs) | |
| elif return_class_token and not return_extra_tokens: | |
| return tuple(zip(outputs, class_tokens)) | |
| elif not return_class_token and return_extra_tokens: | |
| return tuple(zip(outputs, extra_tokens)) | |
| elif return_class_token and return_extra_tokens: | |
| return tuple(zip(outputs, class_tokens, extra_tokens)) | |
| def forward(self, *args, is_training: bool = False, **kwargs) -> List[Dict[str, Tensor]] | Tensor: | |
| ret = self.forward_features(*args, **kwargs) | |
| if is_training: | |
| return ret | |
| else: | |
| return self.head(ret["x_norm_clstoken"]) | |
| def inflate_patch_embed3d_from_2d( | |
| self, | |
| mode: str = "avg", # "avg" or "center" | |
| ) -> None: | |
| """ | |
| Initialize PatchEmbed3D weights by inflating PatchEmbed (2D) weights. | |
| Args: | |
| pe2d: 2D patch embed module with Conv2d `proj` of shape [C_out, C_in, kH, kW]. | |
| pe3d: 3D patch embed module with Conv3d `proj` of shape [C_out, C_in, kD, kH, kW]. | |
| mode: | |
| - "avg": copy the 2D kernel into each temporal slice and divide by kD (I3D-style). | |
| - "center": copy into the center slice only; others set to 0. | |
| """ | |
| assert isinstance(self.patch_embed_2D.proj, nn.Conv2d) and isinstance(self.patch_embed_3D.proj, nn.Conv3d), \ | |
| "pe2d.proj must be Conv2d and pe3d.proj must be Conv3d" | |
| w2 = self.patch_embed_2D.proj.weight.data # [C_out, C_in, kH2, kW2] | |
| b2 = self.patch_embed_2D.proj.bias.data if self.patch_embed_2D.proj.bias is not None else None | |
| w3 = self.patch_embed_3D.proj.weight.data # [C_out, C_in, kD3, kH3, kW3] | |
| b3 = self.patch_embed_3D.proj.bias.data if self.patch_embed_3D.proj.bias is not None else None | |
| C_out2, C_in2, kH2, kW2 = w2.shape | |
| C_out3, C_in3, kD3, kH3, kW3 = w3.shape | |
| # Basic sanity checks | |
| assert C_out2 == C_out3, f"out_channels mismatch: {C_out2} vs {C_out3}" | |
| assert C_in2 == C_in3, f"in_channels mismatch: {C_in2} vs {C_in3}" | |
| assert kH2 == kH3 and kW2 == kW3, \ | |
| f"spatial kernel mismatch: (kH,kW)=({kH2},{kW2}) vs ({kH3},{kW3})" | |
| # Inflate: start from zeros | |
| w3.zero_() | |
| if mode == "avg": | |
| # Copy into every temporal slice and average across time | |
| # So the sum over temporal slices reproduces the 2D response. | |
| for t in range(kD3): | |
| w3[:, :, t, :, :].copy_(w2 / kD3) | |
| elif mode == "center": | |
| center = kD3 // 2 | |
| w3[:, :, center, :, :].copy_(w2) | |
| else: | |
| raise ValueError(f"Unknown mode='{mode}', expected 'avg' or 'center'.") | |
| # Copy bias if present (identical) | |
| if b2 is not None and b3 is not None: | |
| b3.copy_(b2) | |
class Flexi_CT_Backbone(Flexi_CT_Core):
    """``Flexi_CT_Core`` whose 2D and 3D patch embeddings are replaced by ``PatchEmbedND`` modules."""

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 1,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: float | None = None,
        pos_embed_rope_max_period: float | None = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: float | None = None,
        pos_embed_rope_jitter_coords: float | None = None,
        pos_embed_rope_rescale_coords: float | None = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 16,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: float | None = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any | None = None,
        **ignored_kwargs,
    ):
        # Build the core model first; every argument is forwarded unchanged.
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            pos_embed_rope_base=pos_embed_rope_base,
            pos_embed_rope_min_period=pos_embed_rope_min_period,
            pos_embed_rope_max_period=pos_embed_rope_max_period,
            pos_embed_rope_normalize_coords=pos_embed_rope_normalize_coords,
            pos_embed_rope_shift_coords=pos_embed_rope_shift_coords,
            pos_embed_rope_jitter_coords=pos_embed_rope_jitter_coords,
            pos_embed_rope_rescale_coords=pos_embed_rope_rescale_coords,
            pos_embed_rope_dtype=pos_embed_rope_dtype,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            ffn_ratio=ffn_ratio,
            qkv_bias=qkv_bias,
            drop_path_rate=drop_path_rate,
            layerscale_init=layerscale_init,
            norm_layer=norm_layer,
            ffn_layer=ffn_layer,
            ffn_bias=ffn_bias,
            proj_bias=proj_bias,
            n_storage_tokens=n_storage_tokens,
            mask_k_bias=mask_k_bias,
            untie_cls_and_patch_norms=untie_cls_and_patch_norms,
            untie_global_and_local_cls_norm=untie_global_and_local_cls_norm,
            device=device,
            **ignored_kwargs,
        )

        # Swap the embeddings created by the parent for PatchEmbedND ones,
        # 2D first and then 3D.
        embed_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )
        for ndim, attr_name in ((2, "patch_embed_2D"), (3, "patch_embed_3D")):
            setattr(self, attr_name, PatchEmbedND(dim=ndim, **embed_kwargs))
| class Flexi_CT_VLM_Module(nn.Module): | |
| def __init__( | |
| self, | |
| *, | |
| vision_model: nn.Module, | |
| text_model: nn.Module, | |
| embed_dim = 1024, | |
| device: Any | None = None, | |
| **ignored_kwargs, | |
| ): | |
| super().__init__() | |
| # Call parent class constructor with all required parameters | |
| self.vision_model = vision_model | |
| self.text_model = text_model | |
| self.logit_scale = nn.Parameter(torch.empty(1)) | |
| self.vlm_embed_dim = embed_dim | |
| self.vlm_vision_projection = nn.Linear(2*vision_model.embed_dim, self.vlm_embed_dim, bias=False) | |
| self.device = device | |
| def forward(self, images: torch.Tensor, text) -> torch.Tensor: | |
| vision_features = self.vision_model(images, is_training = True) | |
| cls_token = vision_features["x_norm_clstoken"] # [B, D] | |
| patch_tokens = vision_features["x_norm_patchtokens"] # [B, P, D] | |
| # Mean pool patch tokens (like DINOTxt) | |
| mean_patch_token = torch.mean(patch_tokens, dim=1) # [B, D] | |
| # Concatenate CLS + mean(patch) along channel dimension (like DINOTxt) | |
| image_features = torch.cat([cls_token, mean_patch_token], dim=-1) # [B, 2*D] | |
| # Project vision features to VLM embedding space | |
| image_features = self.vlm_vision_projection(image_features) # [B, vlm_embed_dim] | |
| # Normalize image features | |
| image_features = F.normalize(image_features, dim=-1) | |
| text_features = self.text_model(**text) | |
| # Normalize text features | |
| text_features = F.normalize(text_features, dim=-1) | |
| return self.logit_scale.exp(), image_features, text_features | |
def flexi_ct_backbone_base(patch_size=8, **kwargs):
    """Build a ``Flexi_CT_Backbone`` with the base preset.

    The preset fixes embed_dim=864, depth=16, num_heads=12 and ffn_ratio=4;
    every other option is forwarded through ``kwargs``.
    """
    base_config = dict(embed_dim=864, depth=16, num_heads=12, ffn_ratio=4)
    return Flexi_CT_Backbone(patch_size=patch_size, **base_config, **kwargs)