| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from collections import namedtuple |
| from typing import Optional |
|
|
| from timm.models import VisionTransformer |
| import torch |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
| from .model import create_model_from_args |
| from .input_conditioner import get_default_conditioner, InputConditioner |
|
|
|
|
class RADIOConfig(PretrainedConfig):
    """Hugging Face configuration object for RADIO models.

    Attributes:
        args: Dict of keyword arguments describing the backbone. ``RADIOModel``
            converts it into a namedtuple and passes it to
            ``create_model_from_args``.
        version: Version string stored on the config (default ``"v1"``).
        return_summary: Whether ``RADIOModel.forward`` returns the summary
            output.
        return_spatial_features: Whether ``RADIOModel.forward`` returns the
            spatial feature output.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = "v1",
        return_summary: Optional[bool] = True,
        return_spatial_features: Optional[bool] = True,
        **kwargs,
    ):
        # Record the RADIO-specific fields first, then let the base class
        # consume every remaining keyword argument.
        self.args, self.version = args, version
        self.return_summary, self.return_spatial_features = (
            return_summary,
            return_spatial_features,
        )
        super().__init__(**kwargs)
|
|
|
|
class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO.

    Wraps the backbone built by ``create_model_from_args`` together with the
    default input conditioner. The backbone output is split into a summary
    output and spatial features, and the parts requested by the config flags
    are returned.
    """

    config_class = RADIOConfig

    def __init__(self, config):
        """Build the backbone and input conditioner from ``config``.

        Args:
            config: A ``RADIOConfig`` whose ``args`` dict holds the
                backbone-construction arguments.

        Raises:
            ValueError: if ``config.args`` is ``None``.
        """
        super().__init__(config)

        if config.args is None:
            # Without this check, ``None.keys()`` below fails with an opaque
            # AttributeError; report the real problem instead.
            raise ValueError(
                "RADIOConfig.args must be a dict of model-construction "
                "arguments, got None"
            )

        # create_model_from_args is handed attribute-style args, so expose the
        # config dict as a namedtuple whose fields are the dict keys.
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)
        self.config = config
        self.model = create_model_from_args(args)
        self.input_conditioner: InputConditioner = get_default_conditioner()

    def forward(self, x: torch.Tensor):
        """Condition ``x``, run the backbone, and return the requested outputs.

        Args:
            x: Input batch for the input conditioner. Presumably an image
                tensor of shape (B, C, H, W) -- TODO confirm.

        Returns:
            ``(summary, all_feat)`` when both ``config.return_summary`` and
            ``config.return_spatial_features`` are true. ``summary`` alone when
            only ``return_summary`` is true. Otherwise ``all_feat``, which also
            covers the case where both flags are false.

        Raises:
            ValueError: if the backbone returns a single tensor and is not a
                timm ``VisionTransformer``.
        """
        x = self.input_conditioner(x)
        y = self.model.forward_features(x)
        summary, all_feat = self._split_features(y)

        if self.config.return_summary and self.config.return_spatial_features:
            return summary, all_feat
        if self.config.return_summary:
            return summary
        # NOTE: also reached when both flags are False.
        return all_feat

    def _split_features(self, y):
        """Split backbone output ``y`` into ``(summary, all_feat)``.

        Assumes a tensor ``y`` is laid out as (batch, tokens, dim), with any
        non-patch tokens first -- TODO confirm for custom backbones.
        """
        if isinstance(y, (list, tuple)):
            # The backbone already returned the two parts separately.
            summary, all_feat = y
            return summary, all_feat

        if not isinstance(self.model, VisionTransformer):
            raise ValueError("Unsupported model type")

        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            # Summary: the first num_cls_tokens tokens, flattened per sample.
            # Spatial features: everything after the first num_skip tokens.
            summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
            all_feat = y[:, patch_gen.num_skip :]
            return summary, all_feat

        # timm's count of leading non-patch (class/register) tokens.
        num_prefix = self.model.num_prefix_tokens
        if self.model.global_pool == "avg":
            summary = y[:, num_prefix:].mean(dim=1)
        else:
            summary = y[:, 0]
        # Strip every prefix token so the spatial features match the
        # patch_generator branch. The previous code kept prefix tokens in the
        # "avg" case and always dropped exactly one token otherwise.
        all_feat = y[:, num_prefix:]
        return summary, all_feat
|
|