# histaug-conch / modeling_histaug.py
# Uploaded with huggingface_hub (commit 91eed61, verified) by sofieneb.
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_histaug import HistaugConfig
from .histaug_model import HistaugModel
class HistaugPretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`HistaugModel`.

    Builds the core model from the fields of a :class:`HistaugConfig`,
    then freezes it (eval mode, gradients disabled) so it is used for
    inference only. ``forward`` and ``sample_aug_params`` are thin
    proxies to the wrapped model.
    """

    config_class = HistaugConfig

    def __init__(self, config: HistaugConfig, *model_args, **model_kwargs):
        """Instantiate and freeze the core HistaugModel from ``config``.

        Args:
            config: Configuration holding all HistaugModel hyperparameters.
            *model_args: Unused; accepted for PreTrainedModel compatibility.
            **model_kwargs: Extra keyword arguments forwarded to HistaugModel.
        """
        super().__init__(config)
        # Instantiate the core model using values from the config.
        self.histaug = HistaugModel(
            input_dim=config.input_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            use_transform_pos_embeddings=config.use_transform_pos_embeddings,
            positional_encoding_type=config.positional_encoding_type,
            final_activation=config.final_activation,
            embedding_type=config.embedding_type,
            chunk_size=config.chunk_size,
            transforms=config.transforms,
            **model_kwargs,
        )
        self.post_init()
        # Inference-only model: switch to eval mode and freeze every
        # parameter so no optimizer can update it.
        self.histaug.eval()
        for param in self.histaug.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor, aug_params, **kwargs) -> torch.Tensor:
        """Forward pass through the histaug model.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            aug_params: Augmentation parameters dict as expected by HistaugModel.
            **kwargs: Extra keyword arguments forwarded to HistaugModel.

        Returns:
            The wrapped model's output tensor.
        """
        return self.histaug(x, aug_params, **kwargs)

    def sample_aug_params(
        self,
        batch_size: int,
        # Fix: the original annotation was `torch.device = None`, an
        # implicit-Optional default that contradicts the declared type.
        device: Optional[torch.device] = None,
        mode: str = "wsi_wise",
    ):
        """Proxy to HistaugModel.sample_aug_params.

        Args:
            batch_size: Number of parameter sets to sample.
            device: Target device; defaults to CUDA when available, else CPU.
            mode: Sampling mode (default ``"wsi_wise"``) — semantics defined
                by HistaugModel; not visible from this wrapper.

        Returns:
            Whatever HistaugModel.sample_aug_params returns.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self.histaug.sample_aug_params(batch_size=batch_size, device=device, mode=mode)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model and configuration to ``save_directory``.

        Pure pass-through to PreTrainedModel.save_pretrained; kept for
        an explicit, documented public API.
        """
        super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args,
        **kwargs,
    ):
        """Load a model from a pretrained checkpoint.

        Pure pass-through to PreTrainedModel.from_pretrained; kept for
        an explicit, documented public API.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)