# EAGLE / eagle.py
# Uploaded by gabricampanella using huggingface_hub (commit a423288, verified).
import os
import warnings

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

import gma
def _strip_wrapper_prefixes(key):
    """Remove a leading ``module.`` and then a leading ``backbone.`` from a state-dict key."""
    # Only strip true prefixes: a plain str.replace would also mangle keys that merely
    # contain the text (e.g. "...submodule.weight" -> "...subweight").
    for prefix in ('module.', 'backbone.'):
        if key.startswith(prefix):
            key = key[len(prefix):]
    return key


def get_encoder(checkpoint_path='tile_model.pth'):
    """Build the Prov-GigaPath tile encoder and load EAGLE's tile weights into it.

    The architecture comes from timm's ``hf_hub:prov-gigapath/prov-gigapath``
    config with ``pretrained=False``; all weights are taken from the local
    checkpoint instead.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` holding a dict
            whose ``'tile_model'`` entry is the encoder state dict. Defaults to
            ``'tile_model.pth'`` (relative to the current working directory).

    Returns:
        The encoder module with the checkpoint weights loaded.

    Warns:
        UserWarning: if some encoder parameters/buffers have no matching key in
            the checkpoint; those keep their random initialization.
    """
    encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False)
    ckpt = torch.load(checkpoint_path, map_location='cpu')['tile_model']
    # Keys are presumably saved from a wrapped model (DDP "module.", SSL "backbone.");
    # strip those prefixes so they match the bare timm model.
    state_dict = {_strip_wrapper_prefixes(k): v for k, v in ckpt.items()}
    # strict=False tolerates extra checkpoint entries, but a key mismatch would then
    # go unnoticed; surface missing keys instead of silently using random weights.
    result = encoder.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"{len(result.missing_keys)} encoder keys were not found in "
            f"{checkpoint_path!r} and keep their random initialization "
            f"(first few: {result.missing_keys[:5]})"
        )
    return encoder
def get_aggregator(checkpoint_path='slide_model.pth'):
    """Build the GMA slide-level aggregator and load its trained weights.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` holding a dict
            whose ``'slide_model'`` entry is the aggregator state dict. Defaults
            to ``'slide_model.pth'`` (relative to the current working directory).

    Returns:
        The ``gma.GMA`` module with the checkpoint weights loaded (strict key match).
    """
    # NOTE(review): ndim=1536 presumably matches the tile encoder's feature width — confirm.
    aggregator = gma.GMA(ndim=1536, dropout=True)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    aggregator.load_state_dict(checkpoint['slide_model'])
    return aggregator
class EAGLE(nn.Module):
    """Tile encoder followed by a GMA aggregator.

    Both sub-modules are constructed, with weights loaded from disk, by
    ``get_encoder`` and ``get_aggregator`` when the model is instantiated.
    """

    def __init__(self):
        super().__init__()
        # Encoder is registered before the aggregator; this order determines
        # the key order of state_dict(), so keep it.
        self.encoder = get_encoder()
        self.aggregator = get_aggregator()

    def forward(self, x):
        """Encode ``x`` and aggregate the resulting features.

        Args:
            x: Input batch for the tile encoder (presumably the tiles of one
                slide — TODO confirm expected shape and normalization).

        Returns:
            Tuple ``(features, attention, probs)``: the encoder output; the first
            value returned by the aggregator (presumably attention weights —
            confirm against ``gma.GMA``); and the aggregator's third output
            passed through a softmax over dim 1.
        """
        features = self.encoder(x)
        attention, _unused, logits = self.aggregator(features)
        probs = F.softmax(logits, dim=1)
        return features, attention, probs