File size: 2,408 Bytes
91eed61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Optional

import torch
import torch.nn.functional as F

from transformers import PreTrainedModel

from .configuration_histaug import HistaugConfig
from .histaug_model import HistaugModel

class HistaugPretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`HistaugModel`.

    Instantiates the core model from a :class:`HistaugConfig`, then freezes
    it (eval mode, no gradients) — this wrapper is intended for inference
    with pretrained weights, not for training.
    """

    # Tells the HF machinery which config class pairs with this model.
    config_class = HistaugConfig

    def __init__(self, config: HistaugConfig, *model_args, **model_kwargs):
        """Build the wrapped model from *config*.

        Args:
            config: Hyperparameters for the core ``HistaugModel``.
            *model_args: Accepted for HF compatibility; not forwarded.
            **model_kwargs: Extra keyword arguments forwarded verbatim to
                ``HistaugModel``.
        """
        super().__init__(config)

        # Instantiate the core model using values from the config.
        self.histaug = HistaugModel(
            input_dim=config.input_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            use_transform_pos_embeddings=config.use_transform_pos_embeddings,
            positional_encoding_type=config.positional_encoding_type,
            final_activation=config.final_activation,
            embedding_type=config.embedding_type,
            chunk_size=config.chunk_size,
            transforms=config.transforms,
            **model_kwargs,
        )

        # HF weight init / tying hook; must run after submodules exist.
        self.post_init()

        # Inference-only wrapper: switch to eval mode and freeze every
        # parameter so downstream training loops cannot update them.
        self.histaug.eval()
        for p in self.histaug.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor, aug_params, **kwargs) -> torch.Tensor:
        """Forward pass through the wrapped histaug model.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            aug_params: Augmentation parameters dict as expected by
                ``HistaugModel``.
            **kwargs: Forwarded verbatim to ``HistaugModel.forward``.

        Returns:
            The output tensor produced by ``HistaugModel``.
        """
        return self.histaug(x, aug_params, **kwargs)

    def sample_aug_params(
        self,
        batch_size: int,
        device: Optional[torch.device] = None,  # was mis-annotated as non-optional
        mode: str = "wsi_wise",
    ):
        """Proxy to ``HistaugModel.sample_aug_params``.

        Args:
            batch_size: Number of parameter sets to sample.
            device: Target device; defaults to CUDA when available, else CPU.
            mode: Sampling mode (default ``"wsi_wise"``) — semantics defined
                by ``HistaugModel``.
        """
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self.histaug.sample_aug_params(batch_size=batch_size, device=device, mode=mode)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model weights and configuration to *save_directory*.

        Pure pass-through to ``PreTrainedModel.save_pretrained``; kept so the
        override point is explicit for future customization.
        """
        super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args,
        **kwargs,
    ):
        """Load a model from a pretrained checkpoint (HF hub id or local path).

        Pure pass-through to ``PreTrainedModel.from_pretrained``.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)