# File size: 1,661 Bytes
# 4a439c2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class CompactGeoEmbed(nn.Module):
    """Fuse an image with an optional single-channel elevation map into an embedding.

    A MobileNetV2 feature extractor (torchvision ``mobilenet_v2(...).features``)
    encodes the image; its 1280-channel output is reduced to ``embed_c`` channels
    by a 1x1 conv and concatenated with a 3x3-convolved elevation map resampled to
    the same spatial size. A two-layer conv head mixes the two streams, after which
    the map is average-pooled and projected to an L2-normalised ``proj_dim`` vector
    that also feeds a sigmoid "risk" head.

    Args:
        embed_c: Channel width of the fused feature map.
        proj_dim: Size of the projected embedding; the risk head's hidden layer
            uses ``proj_dim // 2`` units.
        pretrained: If True, load torchvision's ``IMAGENET1K_V1`` MobileNetV2
            weights; otherwise the backbone is randomly initialised.
    """

    def __init__(self, embed_c=32, proj_dim=96, pretrained=False):
        super().__init__()
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.mobilenet_v2(weights=weights).features
        # MobileNetV2's feature stack ends in a 1280-channel conv (default width_mult).
        self.reduce = nn.Conv2d(1280, embed_c, 1)
        self.elev_conv = nn.Conv2d(1, embed_c, 3, padding=1)
        self.conv_head = nn.Sequential(
            nn.Conv2d(embed_c * 2, embed_c, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(embed_c, embed_c, 3, padding=1),
            nn.ReLU(True),
        )
        self.proj = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(embed_c, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim),
        )
        self.risk_head = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Linear(proj_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, elev=None):
        """Embed ``img``, optionally conditioned on an elevation map.

        Args:
            img: Image batch fed to the backbone; presumably ``(B, 3, H, W)``
                (TODO confirm channel count with callers).
            elev: Optional elevation map, ``(B, 1, H', W')`` or ``(B, H', W')``.
                Any spatial size is accepted: it is bilinearly resampled to the
                backbone's output resolution and cast to the network's dtype.
                ``None`` is treated as an all-zero map.

        Returns:
            Tuple ``(x, p, risk)``:
                x: fused feature map ``(B, embed_c, h, w)`` at backbone resolution.
                p: L2-normalised embedding ``(B, proj_dim)``.
                risk: per-sample score ``(B,)`` in ``(0, 1)`` (sigmoid output).
        """
        x = self.reduce(self.backbone(img))
        if elev is None:
            # Resampling zeros yields zeros, so build them directly at the feature-map
            # size; new_zeros matches x's dtype/device (works for .half()/.double()).
            elev = x.new_zeros(x.size(0), 1, *x.shape[2:])
        else:
            if elev.dim() == 3:
                elev = elev.unsqueeze(1)  # accept (B, H, W) as a single channel
            # Cast so e.g. a float32 or integer map works with a half/double model.
            elev = elev.to(dtype=x.dtype)
            elev = F.interpolate(
                elev, size=x.shape[2:], mode="bilinear", align_corners=False
            )
        e = self.elev_conv(elev)
        x = self.conv_head(torch.cat([x, e], 1))
        p = F.normalize(self.proj(x), dim=1)
        risk = self.risk_head(p).squeeze(-1)
        return x, p, risk