| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class MultiScaleFusion(nn.Module):
    """Fuse features from three parallel dilated 3x3 convolutions.

    Each branch is a 3x3 ``nn.Conv2d`` mapping ``in_ch -> in_ch`` with
    ``padding="same"`` (default stride 1, so H and W are preserved), at
    dilation 1, 2 and 4 respectively. The three branch outputs are
    concatenated along the channel axis, giving ``3 * in_ch`` channels. The
    result then goes through a bias-free 1x1 convolution to ``out_ch``
    channels. When ``3 * in_ch`` already equals ``out_ch``, it is returned
    unchanged via ``nn.Identity``.

    Args:
        in_ch: Number of input channels (also the width of each branch).
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Attribute names and creation order (scale1, scale2, scale3, proj)
        # determine state_dict keys and the order weights are initialised;
        # keep them stable for checkpoint compatibility.
        self.scale1 = nn.Conv2d(in_ch, in_ch, 3, padding="same")
        self.scale2 = nn.Conv2d(in_ch, in_ch, 3, padding="same", dilation=2)
        self.scale3 = nn.Conv2d(in_ch, in_ch, 3, padding="same", dilation=4)
        # Skip the projection when the concatenated width already matches.
        self.proj = (
            nn.Conv2d(in_ch * 3, out_ch, 1, bias=False)
            if in_ch * 3 != out_ch
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three branches on ``x`` and fuse their outputs.

        Args:
            x: Batched ``(N, in_ch, H, W)`` or unbatched ``(in_ch, H, W)``
                input. Unbatched input relies on ``nn.Conv2d`` accepting 3-D
                tensors (PyTorch >= 1.11); confirm against the pinned version.

        Returns:
            ``(N, out_ch, H, W)`` for batched input, ``(out_ch, H, W)`` for
            unbatched input.
        """
        f1 = self.scale1(x)
        f2 = self.scale2(x)
        f3 = self.scale3(x)
        # dim=-3 is the channel axis for both (N, C, H, W) and (C, H, W);
        # dim=1 would concatenate along H for unbatched input.
        fused = torch.cat([f1, f2, f3], dim=-3)
        return self.proj(fused)
|
|