| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from typing import Optional |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| import torch |
|
|
| from .configuration_tapct import TAPCTConfig |
| from .vision_transformer import vit_small, vit_base |
| from .vision_transformer_3d import vit_3d_small, vit_3d_base |
| from .vision_transformer_base import DinoVisionTransformerBase |
|
|
class TAPCTPreTrainedModel(PreTrainedModel):
    """
    Common base class for TAP-CT models.

    Binds ``TAPCTConfig`` as the configuration class and declares ``"tapct"``
    as the base-model prefix.
    """

    base_model_prefix = "tapct"
    config_class = TAPCTConfig
|
|
|
|
class TAPCTModel(TAPCTPreTrainedModel):
    """
    Bare TAP-CT Vision Transformer backbone, derived from DINOv2
    (https://github.com/facebookresearch/dinov2).

    Only raw hidden states are produced; no task-specific head is attached.
    """

    def __init__(self, config: TAPCTConfig) -> None:
        """
        Build the backbone selected by ``config.model_variant`` and
        ``config.model_size``.

        Parameters
        ----------
        config : TAPCTConfig
            Model configuration. ``model_variant`` must be ``"2d"``, ``"2.5d"``
            or ``"3d"``; ``model_size`` must be ``"small"`` or ``"base"``.

        Raises
        ------
        ValueError
            If the variant or the size is not one of the supported values.
        """
        super().__init__(config)
        self.config = config

        # Resolve the builder first so the (identical) constructor arguments
        # are spelled out only once below. The variant is validated before the
        # size, so an unknown variant is always reported first.
        variant = config.model_variant
        if variant == "2d":
            size = config.model_size
            if size == "small":
                build_backbone = vit_small
            elif size == "base":
                build_backbone = vit_base
            else:
                raise ValueError(f"Model size '{config.model_size}' not supported for 2D")
        elif variant == "2.5d" or variant == "3d":
            # The 2.5D and 3D variants share the 3D backbone builders.
            size = config.model_size
            if size == "small":
                build_backbone = vit_3d_small
            elif size == "base":
                build_backbone = vit_3d_base
            else:
                raise ValueError(f"Model size '{config.model_size}' not supported for 3D")
        else:
            raise ValueError(f"Model variant '{config.model_variant}' not supported. Use '2d', '2.5d', or '3d'.")

        self.model: DinoVisionTransformerBase = build_backbone(
            img_size=config.img_size,
            patch_size=config.patch_size,
            num_register_tokens=config.num_register_tokens,
            in_chans=config.in_chans,
            init_values=config.init_values,
            block_chunks=config.block_chunks,
        )

        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reshape: bool = False
    ) -> BaseModelOutputWithPooling:
        """
        Run the TAP-CT backbone on a batch of images.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
        output_hidden_states : Optional[bool], optional
            Whether to also return the patch features of every block. Falls
            back to ``config.output_hidden_states`` when ``None``.
        return_dict : Optional[bool], optional
            Whether to return a ModelOutput instead of a plain tuple. Falls
            back to ``config.use_return_dict`` when ``None``.
        reshape : bool, default=False
            Whether to reshape output features to spatial dimensions. If True,
            returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead
            of flattened (B, N, C) where N is the number of patches.

        Returns
        -------
        BaseModelOutputWithPooling
            Contains:
            - last_hidden_state: Patch token features from the last layer
            - pooler_output: CLS token from the last layer
            - hidden_states: (optional) Patch features of all blocks if
              output_hidden_states=True

            When ``return_dict`` is falsy, a plain tuple holding the same
            values (with ``None`` entries dropped) is returned instead.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Request every block only when all hidden states are wanted;
        # otherwise the final block alone is enough.
        num_layers = self.model.n_blocks if output_hidden_states else 1
        layer_outputs = self.model.get_intermediate_layers(
            pixel_values,
            n=num_layers,
            return_class_token=True,
            reshape=reshape,
        )

        # Each entry is indexed as (patch features, class token).
        if output_hidden_states:
            hidden_states = tuple(entry[0] for entry in layer_outputs)
            last_hidden_state = hidden_states[-1]
            pooler_output = layer_outputs[-1][1]
        else:
            only_entry = layer_outputs[0]
            last_hidden_state = only_entry[0]
            pooler_output = only_entry[1]
            hidden_states = None

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                hidden_states=hidden_states,
            )

        # Tuple output: keep the same ordering and omit missing values.
        candidates = (last_hidden_state, pooler_output, hidden_states)
        return tuple(value for value in candidates if value is not None)