File size: 1,029 Bytes
a423288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import gma

def get_encoder(ckpt_path: str = 'tile_model.pth'):
    """Build the Prov-GigaPath tile encoder and load fine-tuned weights.

    Args:
        ckpt_path: Path to the tile-model checkpoint. Defaults to the
            original hard-coded 'tile_model.pth' for backward compatibility.

    Returns:
        The timm ViT encoder with the checkpoint's 'tile_model' weights loaded.
    """
    encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False)
    # NOTE(review): torch.load unpickles arbitrary objects; consider
    # weights_only=True if the checkpoint is a plain state dict — verify first.
    ckpt = torch.load(ckpt_path, map_location='cpu')['tile_model']
    # Strip DataParallel ('module.') and wrapper ('backbone.') prefixes so the
    # keys match the bare timm model. Single pass instead of two dict rebuilds.
    ckpt = {k.replace("module.", "").replace("backbone.", ""): v for k, v in ckpt.items()}
    # strict=False tolerates missing/unexpected keys (e.g. a discarded head);
    # any genuine mismatch is silently ignored here.
    encoder.load_state_dict(ckpt, strict=False)
    return encoder

def get_aggregator(ckpt_path: str = 'slide_model.pth'):
    """Build the GMA slide-level aggregator and load trained weights.

    Args:
        ckpt_path: Path to the slide-model checkpoint. Defaults to the
            original hard-coded 'slide_model.pth' for backward compatibility.

    Returns:
        A gma.GMA instance (1536-dim input, dropout enabled) with the
        checkpoint's 'slide_model' weights loaded.
    """
    aggregator = gma.GMA(ndim=1536, dropout=True)
    # NOTE(review): torch.load unpickles arbitrary objects; consider
    # weights_only=True if the checkpoint is a plain state dict — verify first.
    state = torch.load(ckpt_path, map_location='cpu')['slide_model']
    aggregator.load_state_dict(state)
    return aggregator

class EAGLE(nn.Module):
    """End-to-end EAGLE model: tile encoder followed by GMA aggregation.

    The forward pass encodes a batch of tiles, pools them with gated
    attention, and returns the tile features, the attention weights, and
    the softmax-normalized slide-level prediction.
    """

    def __init__(self):
        super().__init__()
        # Both sub-modules come pre-loaded from their respective checkpoints.
        self.encoder = get_encoder()
        self.aggregator = get_aggregator()

    def forward(self, x):
        features = self.encoder(x)
        attention, _, logits = self.aggregator(features)
        probs = F.softmax(logits, dim=1)
        return features, attention, probs