import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None):
    """Return padding that keeps spatial size for stride 1 ('same' padding).

    k may be an int or a sequence of ints; when p is None it defaults to k // 2
    element-wise.
    """
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


class Conv(nn.Module):
    """Standard Conv2d -> BatchNorm2d -> SiLU block (no conv bias; BN provides it)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act is treated as a boolean flag: truthy -> SiLU, falsy -> Identity.
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class DSConv(nn.Module):
    """Depthwise-separable conv: depthwise (groups=c1) then 1x1 pointwise, BN + SiLU."""

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        # Depthwise: one filter per input channel (groups=c1 keeps channels separate).
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        # Pointwise 1x1 mixes channels and changes channel count to c2.
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.pwconv(self.dwconv(x))))


class DS_Bottleneck(nn.Module):
    """Two stacked DSConvs with an optional residual connection.

    The residual is used only when requested AND channel counts match (c1 == c2),
    so the addition is always shape-safe.
    """

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        self.dsconv1 = DSConv(c1, c1, k=3, s=1)
        self.dsconv2 = DSConv(c1, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        # Parses as: (x + f(x)) if self.shortcut else f(x)
        return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))


class DS_C3k(nn.Module):
    """CSP-style block: one branch through n DS_Bottlenecks, one bypass branch,
    concatenated and fused by a 1x1 conv.
    """

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        # Hidden width is int(c2 * e) for both branches.
        self.cv1 = Conv(c1, int(c2 * e), 1, 1)
        self.cv2 = Conv(c1, int(c2 * e), 1, 1)
        self.cv3 = Conv(2 * int(c2 * e), c2, 1, 1)
        self.m = nn.Sequential(*[DS_Bottleneck(int(c2 * e), int(c2 * e), k=k, shortcut=True) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class DS_C3k2(nn.Module):
    """Reduce -> DS_C3k (at full hidden width, e=1.0) -> expand, all via 1x1 convs."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        self.cv1 = Conv(c1, int(c2 * e), 1, 1)
        self.m = DS_C3k(int(c2 * e), int(c2 * e), n=n, k=k, e=1.0)
        self.cv2 = Conv(int(c2 * e), c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))


class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft hyperedge-participation (incidence) matrix for a token set.

    Learned global hyperedge prototypes are offset by a per-sample context vector
    (concatenated avg+max pooled features), then matched against per-token queries
    with multi-head scaled dot-product similarity.

    NOTE(review): the .view calls below require in_channels == num_heads * head_dim;
    with head_dim = max(1, in_channels // num_heads) this assumes in_channels is
    divisible by num_heads — confirm at call sites.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        self.head_dim = max(1, in_channels // num_heads)
        # One learnable prototype vector per hyperedge, shared across samples.
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
        # Maps the (avg || max) pooled context (2C) to a per-hyperedge offset (E*C).
        self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        # x: (B, N, C) — B batches, N tokens, C channels.
        B, N, C = x.shape
        # P: (B, E, C) context-adapted prototypes = global prototypes + mapped
        # per-sample context (avg-pool and max-pool over the token axis).
        P = (
            self.global_proto.unsqueeze(0)
            + self.context_mapper(
                torch.cat(
                    (
                        F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1),
                        F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
                    ),
                    dim=1
                )
            ).view(B, self.num_hyperedges, C))
        # Multi-head similarity: (B, h, N, d) @ (B, h, d, E) -> (B, h, N, E),
        # averaged over heads, transposed to (B, E, N), softmax over tokens.
        return F.softmax((
            (self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
             @ P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1))
            * self.scale
        ).mean(dim=1).permute(0, 2, 1), dim=-1)


class HypergraphConvolution(nn.Module):
    """Two-step hypergraph message passing with a residual connection.

    Given tokens x (B, N, C) and incidence A (B, E, N): aggregate tokens into
    hyperedge features (A @ x), transform them (W_e), scatter back to tokens
    (A^T @ ...), transform again (W_v), and add to x.

    NOTE(review): the residual requires out_channels == in_channels; callers in
    this file always pass equal widths.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        return x + self.act(self.W_v(A.transpose(1, 2).bmm(self.act(self.W_e(A.bmm(x))))))


class AdaptiveHypergraphComputation(nn.Module):
    """Flatten a feature map to tokens, build adaptive hyperedges, run hypergraph
    convolution, and reshape back to (B, C, H, W)."""

    def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, _, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C) token layout expected by the submodules.
        x_flat = x.flatten(2).permute(0, 2, 1)
        return self.hypergraph_conv(x_flat, self.adaptive_hyperedge_gen(x_flat)).permute(0, 2, 1).view(B, -1, H, W)


class C3AH(nn.Module):
    """CSP-style wrapper around AdaptiveHypergraphComputation: hypergraph branch
    (via cv2) concatenated with a 1x1 bypass (cv1), fused by cv3."""

    def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
        super().__init__()
        self.cv1 = Conv(c1, int(c1 * e), 1, 1)
        self.cv2 = Conv(c1, int(c1 * e), 1, 1)
        self.ahc = AdaptiveHypergraphComputation(int(c1 * e), int(c1 * e), num_hyperedges, num_heads)
        self.cv3 = Conv(2 * int(c1 * e), c2, 1, 1)

    def forward(self, x):
        return self.cv3(torch.cat((self.ahc(self.cv2(x)), self.cv1(x)), dim=1))


class HyperACE(nn.Module):
    """Hypergraph-based Adaptive Correlation Enhancement.

    Resamples four pyramid levels to the B4 resolution, fuses them to c_mid
    channels, then splits channels into three groups: a high-order hypergraph
    group (c_h, processed by k parallel C3AH branches), a low-order group
    (c_l, processed by l DS_C3k blocks), and a pass-through shortcut group
    (c_s = remainder). The groups are re-concatenated and fused.

    Args:
        in_channels: 4-tuple of channel counts for the (B2, B3, B4, B5) inputs.
        l: number of stacked DS_C3k blocks in the low-order branch.
    """

    def __init__(self, in_channels, out_channels, num_hyperedges=16, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
        super().__init__()
        c2, c3, c4, c5 = in_channels
        c_mid = c4
        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
        # Channel split sizes; c_s absorbs any rounding remainder so they sum to c_mid.
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        self.high_order_branch = nn.ModuleList([
            C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0) for _ in range(k)
        ])
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
        self.low_order_branch = nn.Sequential(*[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)])
        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x):
        # x: sequence of four feature maps (B2, B3, B4, B5); output is at B4's
        # spatial resolution.
        B2, B3, B4, B5 = x
        _, _, H4, W4 = B4.shape
        x_h, x_l, x_s = self.fuse_conv(
            torch.cat(
                (
                    F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False),
                    F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False),
                    B4,
                    F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
                ),
                dim=1
            )
        ).split([self.c_h, self.c_l, self.c_s], dim=1)
        return self.final_fuse(
            torch.cat(
                (
                    # All k C3AH outputs are concatenated, then fused back to c_h.
                    self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1)),
                    self.low_order_branch(x_l),
                    x_s
                ),
                dim=1
            )
        )


class GatedFusion(nn.Module):
    """Residual fusion f_in + gamma * h with a learnable per-channel gate.

    gamma starts at zero, so at initialization the module is an identity on f_in.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        return f_in + self.gamma * h


class YOLO13Encoder(nn.Module):
    """Four-stage DSConv/DS_C3k2 backbone.

    Each stage halves spatial resolution (stride-2 DSConv) and doubles channels.
    Returns the four stage outputs [p2, p3, p4, p5]; widths in self.out_channels.
    """

    def __init__(self, in_channels, base_channels=32):
        super().__init__()
        # Stride-1 stem: changes channels only, keeps resolution.
        self.stem = DSConv(in_channels, base_channels, k=3, s=1)
        self.p2 = nn.Sequential(
            DSConv(base_channels, base_channels*2, k=3, s=(2, 2)),
            DS_C3k2(base_channels*2, base_channels*2, n=1)
        )
        self.p3 = nn.Sequential(
            DSConv(base_channels*2, base_channels*4, k=3, s=(2, 2)),
            DS_C3k2(base_channels*4, base_channels*4, n=2)
        )
        self.p4 = nn.Sequential(
            DSConv(base_channels*4, base_channels*8, k=3, s=(2, 2)),
            DS_C3k2(base_channels*8, base_channels*8, n=2)
        )
        self.p5 = nn.Sequential(
            DSConv(base_channels*8, base_channels*16, k=3, s=(2, 2)),
            DS_C3k2(base_channels*16, base_channels*16, n=1)
        )
        self.out_channels = [base_channels*2, base_channels*4, base_channels*8, base_channels*16]

    def forward(self, x):
        p2 = self.p2(self.stem(x))
        p3 = self.p3(p2)
        p4 = self.p4(p3)
        p5 = self.p5(p4)
        return [p2, p3, p4, p5]


class YOLO13FullPADDecoder(nn.Module):
    """Top-down decoder that injects the HyperACE feature at every level.

    At each level d_i the HyperACE map is resized to match, projected to the
    level's width (h_to_d*), gated in (GatedFusion), refined/reduced by a
    DS_C3k2, upsampled to the next level's size, and added to that level's
    1x1-projected encoder skip.
    """

    def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        # Decoder widths mirror the encoder widths at each level.
        c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2
        # Per-level projections of the HyperACE feature to the decoder width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)
        # 1x1 projections of the encoder skips (identity widths, learnable mix).
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
        # Each up_d* also reduces channels to the next (finer) level's width.
        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
        self.final_conv = Conv(c_d2, out_channels_final, 1, 1)

    def forward(self, enc_feats, h_ace):
        # enc_feats: [p2, p3, p4, p5] from the encoder; h_ace: HyperACE output.
        p2, p3, p4, p5 = enc_feats
        d5 = self.skip_p5(p5)
        # NOTE: channel reduction (up_d*) happens AFTER the bilinear upsample.
        d4 = self.up_d5(
            F.interpolate(
                self.fusion_d5(d5, self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear', align_corners=False))),
                size=p4.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p4(p4)
        d3 = self.up_d4(
            F.interpolate(
                self.fusion_d4(d4, self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear', align_corners=False))),
                size=p3.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p3(p3)
        d2 = self.up_d3(
            F.interpolate(
                self.fusion_d3(d3, self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear', align_corners=False))),
                size=p2.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p2(p2)
        return self.final_conv(
            self.final_d2(
                self.fusion_d2(
                    d2,
                    self.h_to_d2(
                        F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear', align_corners=False)
                    )
                )
            )
        )