| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from collections import namedtuple |
| from typing import Optional |
|
|
| from einops import rearrange |
| from timm.models import VisionTransformer |
| import torch |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
| from .eradio_model import eradio |
| from .radio_model import create_model_from_args |
| from .radio_model import RADIOModel as RADIOModelBase |
| from .input_conditioner import get_default_conditioner, InputConditioner |
|
|
|
|
class RADIOConfig(PretrainedConfig):
    """Hugging Face configuration object for RADIO models.

    Args:
        args: Dictionary of backbone construction arguments. ``RADIOModel``
            wraps it in a namedtuple and hands it to
            ``create_model_from_args``.
        version: Version string stored on the config (default ``"v1"``).
        return_summary: Stored on the config; ``RADIOModel`` forwards it to
            the underlying RADIO model.
        return_spatial_features: Stored on the config; ``RADIOModel``
            forwards it to the underlying RADIO model.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = "v1",
        return_summary: Optional[bool] = True,
        return_spatial_features: Optional[bool] = True,
        **kwargs,
    ):
        # RADIO-specific fields are recorded first, then the remaining
        # keyword arguments are delegated to the Hugging Face base class.
        self.args, self.version = args, version
        self.return_summary, self.return_spatial_features = (
            return_summary,
            return_spatial_features,
        )
        super().__init__(**kwargs)
|
|
|
|
class RADIOModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a RADIO model.

    Inheriting from ``PreTrainedModel`` provides Hugging Face's loading and
    saving functionality. The wrapped network is stored in
    ``self.radio_model``, and ``model`` / ``input_conditioner`` expose its
    corresponding attributes.
    """

    config_class = RADIOConfig

    def __init__(self, config):
        super().__init__(config)

        # The plain dict held by the config is converted to a namedtuple
        # (fields named after the dict keys) before model construction.
        args_tuple_cls = namedtuple("RADIOArgs", config.args.keys())
        radio_args = args_tuple_cls(**config.args)

        self.config = config
        backbone = create_model_from_args(radio_args)
        conditioner: InputConditioner = get_default_conditioner()

        self.radio_model = RADIOModelBase(
            backbone,
            conditioner,
            config.return_summary,
            config.return_spatial_features,
        )

    @property
    def model(self) -> VisionTransformer:
        """The backbone held by the wrapped RADIO model."""
        wrapped = self.radio_model
        return wrapped.model

    @property
    def input_conditioner(self) -> InputConditioner:
        """The input conditioner held by the wrapped RADIO model."""
        wrapped = self.radio_model
        return wrapped.input_conditioner

    def forward(self, x: torch.Tensor):
        """Delegate directly to the wrapped model's ``forward`` method."""
        wrapped = self.radio_model
        return wrapped.forward(x)
|
|
|
|
class ERADIOConfig(PretrainedConfig):
    """Hugging Face configuration object for ERADIO models.

    Args:
        args: Dictionary of keyword arguments. ``ERADIOModel`` adds a few
            entries to it and unpacks it into ``eradio(...)``.
        version: Version string stored on the config (default ``"v1"``).
        return_summary: Whether ``ERADIOModel.forward`` returns the summary
            output.
        return_spatial_features: Whether ``ERADIOModel.forward`` returns the
            spatial features; also forwarded to the backbone as
            ``return_full_features``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = "v1",
        return_summary: Optional[bool] = True,
        return_spatial_features: Optional[bool] = True,
        **kwargs,
    ):
        # Same ordering as the original: store our own fields first, then
        # let the base class consume whatever keyword arguments remain.
        self.args, self.version = args, version
        self.return_summary, self.return_spatial_features = (
            return_summary,
            return_spatial_features,
        )
        super().__init__(**kwargs)
|
|
|
|
class ERADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for ERADIO.

    This class inherits from PreTrainedModel, which provides
    HuggingFace's functionality for loading and saving models.
    """

    config_class = ERADIOConfig

    def __init__(self, config):
        super().__init__(config)

        # Fixed backbone settings used by this wrapper. They are written
        # into ``config.args`` in place, so the caller's config object is
        # updated as well.
        config.args["in_chans"] = 3
        config.args["num_classes"] = 0
        config.args["return_full_features"] = config.return_spatial_features

        self.config = config
        model = eradio(**config.args)
        self.input_conditioner: InputConditioner = get_default_conditioner()
        self.return_summary = config.return_summary
        self.return_spatial_features = config.return_spatial_features
        self.model = model

    def forward(self, x: torch.Tensor):
        """Condition ``x``, run the backbone once, and select the outputs.

        Returns:
            ``(summary, features)`` when both ``return_summary`` and
            ``return_spatial_features`` are set; ``summary`` alone when only
            ``return_summary`` is set; otherwise ``features``. ``features``
            is ``None`` whenever the backbone's ``forward_features`` does not
            return a tuple.
        """
        x = self.input_conditioner(x)
        # Call the backbone exactly once. A second, unconditionally unpacked
        # call here used to double the cost and break (or silently split the
        # batch) when ``forward_features`` returned a single value.
        y = self.model.forward_features(x)

        if isinstance(y, tuple):
            summary, features = y
            # Per the einops pattern: (b, c, h, w) -> (b, h*w, c).
            features = rearrange(features, 'b c h w -> b (h w) c')
        else:
            summary = y
            features = None

        if self.return_summary and self.return_spatial_features:
            return summary, features
        elif self.return_summary:
            return summary
        return features
|
|