# 3rd-hack-nation / src/model.py
# Uploaded by Pingsz ("Upload 9 files", commit 4a439c2, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class CompactGeoEmbed(nn.Module):
    """Compact image + elevation embedding network with a scalar risk head.

    Pipeline (as built in ``__init__``):

    1. ``backbone``: torchvision MobileNetV2 ``features`` (``IMAGENET1K_V1``
       weights when ``pretrained`` is true, random init otherwise).
    2. ``reduce``: 1x1 conv from the backbone's 1280 channels to ``embed_c``.
    3. ``elev_conv``: 3x3 conv lifting a single-channel map (presumably an
       elevation raster), resized to the feature grid, to ``embed_c``
       channels.
    4. ``conv_head``: two 3x3 conv + ReLU layers over the concatenated
       image/elevation features.
    5. ``proj``: global average pool + 2-layer MLP to ``proj_dim``; the result
       is L2-normalised in ``forward``.
    6. ``risk_head``: MLP + sigmoid mapping the normalised embedding to one
       score in (0, 1) per sample.

    Args:
        embed_c: Channel width of the fused spatial feature map.
        proj_dim: Size of the projected embedding. ``proj_dim // 2`` is the
            risk head's hidden width, so it should be at least 2.
        pretrained: Load torchvision's ImageNet MobileNetV2 weights.
    """

    def __init__(self, embed_c: int = 32, proj_dim: int = 96, pretrained: bool = False):
        super().__init__()
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.mobilenet_v2(weights=weights).features
        # MobileNetV2's final feature layer outputs 1280 channels.
        self.reduce = nn.Conv2d(1280, embed_c, 1)
        self.elev_conv = nn.Conv2d(1, embed_c, 3, padding=1)
        self.conv_head = nn.Sequential(
            nn.Conv2d(embed_c * 2, embed_c, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(embed_c, embed_c, 3, padding=1),
            nn.ReLU(True),
        )
        self.proj = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(embed_c, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim),
        )
        self.risk_head = nn.Sequential(
            nn.Linear(proj_dim, proj_dim // 2),
            nn.ReLU(),
            nn.Linear(proj_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, elev=None):
        """Embed an image batch, optionally fused with a single-channel map.

        Args:
            img: Image batch for the MobileNetV2 backbone; assumed to be
                (B, 3, H, W) — TODO confirm the normalisation callers apply.
            elev: Optional (B, 1, He, We) map at any resolution. It is cast to
                the feature dtype and bilinearly resized to the backbone's
                feature grid. ``None`` is treated as an all-zero map.

        Returns:
            Tuple ``(x, p, risk)``:
            ``x`` is the fused feature map, shape (B, embed_c, h, w);
            ``p`` is the L2-normalised embedding, shape (B, proj_dim);
            ``risk`` is the per-sample sigmoid score, shape (B,).
        """
        x = self.reduce(self.backbone(img))
        if elev is None:
            # Zeros are unchanged by resizing, so build them directly on the
            # feature grid; new_zeros also matches x's dtype/device, which
            # keeps half- and double-precision models working.
            elev = x.new_zeros(x.size(0), 1, *x.shape[2:])
        else:
            # Cast first: conv weights require a matching dtype, and bilinear
            # interpolation rejects integer tensors.
            elev = elev.to(dtype=x.dtype)
            elev = F.interpolate(elev, size=x.shape[2:], mode="bilinear", align_corners=False)
        e = self.elev_conv(elev)
        x = self.conv_head(torch.cat([x, e], dim=1))
        p = F.normalize(self.proj(x), dim=1)
        risk = self.risk_head(p).squeeze(-1)
        return x, p, risk