| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
| import gma | |
def get_encoder(ckpt_path='tile_model.pth'):
    """Build the prov-gigapath tile encoder and load fine-tuned weights.

    Parameters
    ----------
    ckpt_path : str
        Path to a checkpoint whose ``'tile_model'`` entry holds the state
        dict (default preserves the original hard-coded ``'tile_model.pth'``).

    Returns
    -------
    torch.nn.Module
        The timm model with the checkpoint weights loaded (non-strict).
    """
    encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False)
    state = torch.load(ckpt_path, map_location='cpu')['tile_model']
    # Strip wrapper prefixes ("module." from DataParallel, "backbone." from the
    # training harness) in a single pass so keys match the bare timm model.
    state = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state.items()}
    # strict=False tolerates head/extra keys, but don't discard the report:
    # a silent mismatch here would mean running with random weights.
    missing, unexpected = encoder.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"get_encoder: {len(missing)} missing / {len(unexpected)} unexpected keys when loading {ckpt_path}")
    return encoder
def get_aggregator(ckpt_path='slide_model.pth'):
    """Build the GMA slide-level aggregator and load its trained weights.

    Parameters
    ----------
    ckpt_path : str
        Path to a checkpoint whose ``'slide_model'`` entry holds the state
        dict (default preserves the original hard-coded ``'slide_model.pth'``).

    Returns
    -------
    gma.GMA
        Aggregator over 1536-dim tile features (prov-gigapath embedding size)
        with weights loaded strictly.
    """
    aggregator = gma.GMA(ndim=1536, dropout=True)
    state = torch.load(ckpt_path, map_location='cpu')['slide_model']
    aggregator.load_state_dict(state)
    return aggregator
class EAGLE(nn.Module):
    """Two-stage EAGLE model: tile encoder followed by a GMA aggregator.

    The encoder embeds a batch of tiles; the aggregator pools those
    embeddings with attention and produces class scores.
    """

    def __init__(self):
        super().__init__()
        self.encoder = get_encoder()
        self.aggregator = get_aggregator()

    def forward(self, x):
        """Return ``(features, attention, probs)`` for a batch of tiles.

        ``features`` are the per-tile encoder embeddings, ``attention`` is
        the aggregator's attention output, and ``probs`` are softmax class
        probabilities over dim 1.
        """
        features = self.encoder(x)
        attention, _, logits = self.aggregator(features)
        probs = F.softmax(logits, dim=1)
        return features, attention, probs