| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import Optional |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutputWithPooling |
| | import torch |
| |
|
| | from .configuration_tapct import TAPCTConfig |
| | from .vision_transformer import vit_small, vit_base |
| | from .vision_transformer_3d import vit_3d_small, vit_3d_base |
| | from .vision_transformer_base import DinoVisionTransformerBase |
| |
|
class TAPCTPreTrainedModel(PreTrainedModel):
    """Abstract base class connecting TAP-CT models to the Hugging Face
    ``PreTrainedModel`` machinery (weight init, ``from_pretrained`` /
    ``save_pretrained`` serialization, etc.).
    """

    # Config class used by from_pretrained() to deserialize model configs.
    config_class = TAPCTConfig
    # Parameter-name prefix used when loading/saving checkpoints.
    base_model_prefix = "tapct"
|
| | class TAPCTModel(TAPCTPreTrainedModel): |
| | """ |
| | TAP-CT Vision Transformer model based on DINOv2: https://github.com/facebookresearch/dinov2. |
| | |
| | This model outputs raw hidden states and does not include any task-specific head. |
| | """ |
| | |
| | def __init__(self, config: TAPCTConfig) -> None: |
| | super().__init__(config) |
| | self.config = config |
| | self.model: DinoVisionTransformerBase |
| |
|
| | match config.model_variant: |
| | case "2d": |
| | if config.model_size == "small": |
| | self.model = vit_small( |
| | img_size=config.img_size, |
| | patch_size=config.patch_size, |
| | num_register_tokens=config.num_register_tokens, |
| | in_chans=config.in_chans, |
| | init_values=config.init_values, |
| | block_chunks=config.block_chunks |
| | ) |
| | elif config.model_size == "base": |
| | self.model = vit_base( |
| | img_size=config.img_size, |
| | patch_size=config.patch_size, |
| | num_register_tokens=config.num_register_tokens, |
| | in_chans=config.in_chans, |
| | init_values=config.init_values, |
| | block_chunks=config.block_chunks |
| | ) |
| | else: |
| | raise ValueError(f"Model size '{config.model_size}' not supported for 2D") |
| | |
| | case "2.5d" | "3d": |
| | if config.model_size == "small": |
| | self.model = vit_3d_small( |
| | img_size=config.img_size, |
| | patch_size=config.patch_size, |
| | num_register_tokens=config.num_register_tokens, |
| | in_chans=config.in_chans, |
| | init_values=config.init_values, |
| | block_chunks=config.block_chunks |
| | ) |
| | elif config.model_size == "base": |
| | self.model = vit_3d_base( |
| | img_size=config.img_size, |
| | patch_size=config.patch_size, |
| | num_register_tokens=config.num_register_tokens, |
| | in_chans=config.in_chans, |
| | init_values=config.init_values, |
| | block_chunks=config.block_chunks |
| | ) |
| | else: |
| | raise ValueError(f"Model size '{config.model_size}' not supported for 3D") |
| | |
| | case _: |
| | raise ValueError(f"Model variant '{config.model_variant}' not supported. Use '2d', '2.5d', or '3d'.") |
| | |
| | |
| | self.post_init() |
| |
|
| | def forward( |
| | self, |
| | pixel_values: torch.Tensor, |
| | output_hidden_states: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | reshape: bool = False |
| | ) -> BaseModelOutputWithPooling: |
| | """ |
| | Forward pass of the TAP-CT model. |
| | |
| | Parameters |
| | ---------- |
| | pixel_values : torch.Tensor |
| | Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D. |
| | output_hidden_states : Optional[bool], optional |
| | Whether to return hidden states from all layers. |
| | return_dict : Optional[bool], optional |
| | Whether to return a ModelOutput instead of a plain tuple. |
| | reshape : bool, default=False |
| | Whether to reshape output features to spatial dimensions. If True, |
| | returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead |
| | of flattened (B, N, C) where N is the number of patches. |
| | |
| | Returns |
| | ------- |
| | BaseModelOutputWithPooling |
| | Contains: |
| | - last_hidden_state: Patch token features from the last layer |
| | - pooler_output: CLS token from the last layer |
| | - hidden_states: (optional) All hidden states if output_hidden_states=True |
| | """ |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| |
|
| | if output_hidden_states: |
| | outputs_tuple = self.model.get_intermediate_layers( |
| | pixel_values, |
| | n=self.model.n_blocks, |
| | return_class_token=True, |
| | reshape=reshape |
| | ) |
| | outputs = tuple(o[0] for o in outputs_tuple) |
| | class_tokens = tuple(o[1] for o in outputs_tuple) |
| | |
| | last_hidden_state = outputs[-1] |
| | pooler_output = class_tokens[-1] |
| | hidden_states = outputs |
| | else: |
| | outputs_tuple = self.model.get_intermediate_layers( |
| | pixel_values, |
| | n=1, |
| | return_class_token=True, |
| | reshape=reshape |
| | ) |
| | last_hidden_state = outputs_tuple[0][0] |
| | pooler_output = outputs_tuple[0][1] |
| | hidden_states = None |
| |
|
| | if not return_dict: |
| | return tuple( |
| | v for v in [last_hidden_state, pooler_output, hidden_states] |
| | if v is not None |
| | ) |
| |
|
| | return BaseModelOutputWithPooling( |
| | last_hidden_state=last_hidden_state, |
| | pooler_output=pooler_output, |
| | hidden_states=hidden_states |
| | ) |