import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiScaleFusion(nn.Module):
    """Fuse features from three parallel dilated 3x3 convolutions.

    Each branch is a 3x3 convolution mapping ``in_ch`` -> ``in_ch`` channels
    with ``padding="same"`` (spatial size preserved) and dilation 1, 2 and 4
    respectively. The three branch outputs are concatenated along the channel
    axis (giving ``3 * in_ch`` channels). They are then projected to
    ``out_ch`` channels by a bias-free 1x1 convolution, or passed through
    unchanged when ``3 * in_ch == out_ch``.

    NOTE(review): no activation or normalisation sits between the branches
    and the projection, so the whole module is an affine map of its input.
    Confirm this is intended.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Creation order (scale1, scale2, scale3, proj) and attribute names
        # are kept so parameter initialisation and state_dict keys are stable.
        self.scale1 = nn.Conv2d(in_ch, in_ch, 3, padding="same")
        self.scale2 = nn.Conv2d(in_ch, in_ch, 3, padding="same", dilation=2)
        self.scale3 = nn.Conv2d(in_ch, in_ch, 3, padding="same", dilation=4)
        fused_ch = in_ch * 3
        # Skip the projection when the concatenation already has the
        # requested width.
        self.proj = (
            nn.Identity()
            if fused_ch == out_ch
            else nn.Conv2d(fused_ch, out_ch, 1, bias=False)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three dilated branches, concatenate them, and project.

        Args:
            x: Batched ``(N, in_ch, H, W)`` or unbatched ``(in_ch, H, W)``
                input (both layouts are accepted by ``nn.Conv2d``).

        Returns:
            ``(N, out_ch, H, W)`` for batched input, or ``(out_ch, H, W)``
            for unbatched input.
        """
        f1 = self.scale1(x)
        f2 = self.scale2(x)
        f3 = self.scale3(x)
        # dim=-3 is the channel axis for both (N, C, H, W) and (C, H, W);
        # dim=1 would wrongly stack along H for unbatched input.
        fused = torch.cat([f1, f2, f3], dim=-3)
        return self.proj(fused)