|
|
import os
from typing import Optional, Union

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_histaug import HistaugConfig
from .histaug_model import HistaugModel
|
|
|
|
|
class HistaugPretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a frozen ``HistaugModel``.

    The wrapped network is built from the fields of a ``HistaugConfig``. At
    construction time it is put in eval mode and all of its parameters get
    ``requires_grad = False``, so it acts as a fixed, non-trainable module.

    NOTE(review): ``nn.Module.train()`` recurses into children, so calling
    ``.train()`` on this model (or on a parent module containing it) switches
    ``self.histaug`` back to training mode. Parameters stay frozen, but any
    mode-dependent layers inside ``HistaugModel`` would change behaviour.
    Confirm whether callers rely on it staying in eval mode.
    """

    config_class = HistaugConfig

    def __init__(self, config: HistaugConfig, *model_args, **model_kwargs):
        """Build the wrapped ``HistaugModel`` from ``config`` and freeze it.

        Args:
            config: Configuration whose fields are forwarded to the
                ``HistaugModel`` constructor.
            *model_args: Accepted for signature compatibility with
                ``PreTrainedModel`` subclasses; not used.
            **model_kwargs: Extra keyword arguments forwarded verbatim to the
                ``HistaugModel`` constructor.
        """
        super().__init__(config)

        self.histaug = HistaugModel(
            input_dim=config.input_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            use_transform_pos_embeddings=config.use_transform_pos_embeddings,
            positional_encoding_type=config.positional_encoding_type,
            final_activation=config.final_activation,
            embedding_type=config.embedding_type,
            chunk_size=config.chunk_size,
            transforms=config.transforms,
            **model_kwargs,
        )

        # Run transformers' post-construction hook (e.g. weight initialisation)
        # before freezing, so it operates on the fully built submodule.
        self.post_init()

        # The wrapped network is used as a fixed module: eval mode, no gradients.
        self.histaug.eval()
        self.histaug.requires_grad_(False)

    def forward(self, x: torch.Tensor, aug_params, **kwargs) -> torch.Tensor:
        """Run ``x`` through the wrapped ``HistaugModel``.

        Args:
            x: Input tensor. Documented as shape ``(batch_size, input_dim)``;
                TODO confirm against ``HistaugModel.forward``.
            aug_params: Augmentation parameters (documented as a dict) in the
                format expected by ``HistaugModel``. Presumably the structure
                returned by ``sample_aug_params``; verify.
            **kwargs: Forwarded verbatim to ``HistaugModel``.

        Returns:
            Whatever ``HistaugModel`` returns for these inputs (annotated as a
            tensor).
        """
        return self.histaug(x, aug_params, **kwargs)

    def sample_aug_params(
        self,
        batch_size: int,
        device: Optional[Union[str, torch.device]] = None,
        mode: str = "wsi_wise",
    ):
        """Proxy to ``HistaugModel.sample_aug_params``.

        Args:
            batch_size: Forwarded as-is. Presumably the number of parameter
                sets to draw.
            device: Device forwarded to the wrapped sampler. When ``None``,
                the device of this model's parameters is used, so the
                sampled parameters live where ``forward`` will consume them.
            mode: Sampling mode forwarded as-is (default ``"wsi_wise"``).

        Returns:
            The value returned by ``HistaugModel.sample_aug_params``.
        """
        if device is None:
            # Follow the model rather than guessing from CUDA availability.
            # Otherwise a model kept on CPU (or on a non-default GPU) on a
            # CUDA machine would be handed parameters on the wrong device.
            device = self.device
        return self.histaug.sample_aug_params(
            batch_size=batch_size, device=device, mode=mode
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save the model and its configuration to ``save_directory``.

        Thin delegation to ``PreTrainedModel.save_pretrained``; ``**kwargs``
        are forwarded unchanged.
        """
        super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *model_args,
        **kwargs,
    ):
        """Load a model from a pretrained checkpoint.

        Thin delegation to ``PreTrainedModel.from_pretrained``.
        ``*model_args`` and ``**kwargs`` are forwarded unchanged.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)