🔧 Fix JIT compilation error in Mish activation function
Browse files
breakthrough_unet_v6_simple.py
CHANGED
|
@@ -104,28 +104,17 @@ class EnhancedMultiScaleFusion(nn.Module):
|
|
| 104 |
feat_1x1 = self.mish_1x1(self.conv_1x1(x))
|
| 105 |
feat_3x3 = self.mish_3x3(self.conv_3x3(x))
|
| 106 |
|
| 107 |
-
# 🔬 RESEARCH:
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
feat_dilated = self.mish_dilated(self.conv_dilated(x))
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# Full feature concatenation
|
| 118 |
-
all_features = torch.cat([feat_1x1, feat_3x3, feat_5x5, feat_dilated, global_upsampled], dim=1)
|
| 119 |
-
else:
|
| 120 |
-
# 🔬 RESEARCH: Simplified processing for CPU inference on smaller images
|
| 121 |
-
feat_dilated = self.mish_dilated(self.conv_dilated(x))
|
| 122 |
-
|
| 123 |
-
# Skip expensive 5x5 conv and global context for speed
|
| 124 |
-
# Pad to maintain expected channel count
|
| 125 |
-
feat_5x5_placeholder = torch.zeros_like(feat_1x1)
|
| 126 |
-
global_placeholder = torch.zeros_like(feat_1x1)
|
| 127 |
|
| 128 |
-
|
|
|
|
| 129 |
|
| 130 |
# Fuse features
|
| 131 |
fused = self.fusion(all_features)
|
|
|
|
| 104 |
feat_1x1 = self.mish_1x1(self.conv_1x1(x))
|
| 105 |
feat_3x3 = self.mish_3x3(self.conv_3x3(x))
|
| 106 |
|
| 107 |
+
# 🔬 RESEARCH: Always use full processing for stability (conditional optimization removed)
|
| 108 |
+
feat_5x5 = self.mish_5x5(self.conv_5x5(x))
|
| 109 |
+
feat_dilated = self.mish_dilated(self.conv_dilated(x))
|
|
|
|
| 110 |
|
| 111 |
+
# Global context
|
| 112 |
+
global_feat = self.global_pool(x)
|
| 113 |
+
global_conv = self.mish_global(self.global_conv(global_feat))
|
| 114 |
+
global_upsampled = F.interpolate(global_conv, size=(H, W), mode='bilinear', align_corners=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
+
# Concatenate all features
|
| 117 |
+
all_features = torch.cat([feat_1x1, feat_3x3, feat_5x5, feat_dilated, global_upsampled], dim=1)
|
| 118 |
|
| 119 |
# Fuse features
|
| 120 |
fused = self.fusion(all_features)
|