"""EAGLE: a Prov-GigaPath tile encoder followed by a GMA slide aggregator.

Both sub-models load their weights from local checkpoint files (by default
``tile_model.pth`` and ``slide_model.pth`` in the current working directory).
"""
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import gma

logger = logging.getLogger(__name__)

# Wrapper prefixes stripped from tile-checkpoint keys. ``module.`` is what
# ``nn.DataParallel`` / DDP prepend; ``backbone.`` presumably comes from the
# training wrapper used for the tile model — TODO confirm.
_TILE_KEY_PREFIXES = ("module.", "backbone.")


def _strip_prefixes(key, prefixes=_TILE_KEY_PREFIXES):
    """Return *key* with any leading wrapper *prefixes* removed.

    Prefixes are removed repeatedly from the START of the key only, in any
    order, so ``"module.backbone.x"`` and ``"backbone.module.x"`` both become
    ``"x"``. Unlike ``str.replace``, an inner substring such as the
    ``"module."`` in ``"blocks.0.submodule.w"`` is left untouched.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


def get_encoder(ckpt_path="tile_model.pth"):
    """Build the Prov-GigaPath tile encoder and load weights from a checkpoint.

    Args:
        ckpt_path: Path to a ``torch.save`` file whose ``"tile_model"`` entry
            is the encoder state dict (keys may carry ``module.`` /
            ``backbone.`` prefixes, which are stripped).

    Returns:
        The ``timm`` model with the checkpoint weights loaded non-strictly.
        Missing and unexpected keys are logged rather than raised.
    """
    # pretrained=False: only the architecture/config is taken from the hub id;
    # the weights come from the local checkpoint below.
    encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")["tile_model"]
    state_dict = {_strip_prefixes(k): v for k, v in ckpt.items()}
    # strict=False tolerates key mismatches, but a missing key leaves that
    # encoder parameter at its random initialization, so surface it.
    result = encoder.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        logger.warning(
            "Tile checkpoint %s is missing %d encoder keys (left at init): %s",
            ckpt_path, len(result.missing_keys), result.missing_keys,
        )
    if result.unexpected_keys:
        logger.info(
            "Tile checkpoint %s has %d keys not used by the encoder: %s",
            ckpt_path, len(result.unexpected_keys), result.unexpected_keys,
        )
    return encoder


def get_aggregator(ckpt_path="slide_model.pth"):
    """Build the GMA slide aggregator and load its weights strictly.

    Args:
        ckpt_path: Path to a ``torch.save`` file whose ``"slide_model"`` entry
            is the aggregator state dict.

    Returns:
        The ``gma.GMA`` module with weights loaded (strict key matching).
    """
    # ndim=1536 presumably matches the tile encoder's feature size — TODO
    # confirm against the encoder config.
    aggregator = gma.GMA(ndim=1536, dropout=True)
    # NOTE(review): torch.load unpickles arbitrary objects — trusted files only.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    aggregator.load_state_dict(ckpt["slide_model"])
    return aggregator


class EAGLE(nn.Module):
    """Tile encoder + GMA aggregator producing slide-level class probabilities.

    Like any ``nn.Module`` it is constructed in training mode; call
    ``.eval()`` before inference so that dropout in the aggregator (built with
    ``dropout=True``) is presumably disabled — TODO confirm against gma.

    Args:
        tile_ckpt_path: Checkpoint for the tile encoder (see ``get_encoder``).
        slide_ckpt_path: Checkpoint for the aggregator (see ``get_aggregator``).
    """

    def __init__(self, tile_ckpt_path="tile_model.pth", slide_ckpt_path="slide_model.pth"):
        super().__init__()
        self.encoder = get_encoder(tile_ckpt_path)
        self.aggregator = get_aggregator(slide_ckpt_path)

    def forward(self, x):
        """Encode tiles, aggregate them, and return class probabilities.

        Args:
            x: Input to the tile encoder; presumably a batch of tiles from one
                slide shaped ``(num_tiles, 3, H, W)`` — TODO confirm.

        Returns:
            ``(h, A, probs)``: the encoder features, the aggregator's first
            output (presumably per-tile attention — TODO confirm against gma),
            and the aggregator's third output softmax-normalized over dim 1.
        """
        h = self.encoder(x)
        A, _, logits = self.aggregator(h)
        probs = F.softmax(logits, dim=1)
        return h, A, probs