yamero999 committed on
Commit
3c3aff4
·
verified ·
1 Parent(s): 8c1c6db

🔧 Fix JIT compilation error in Mish activation function

Browse files
Files changed (1) hide show
  1. breakthrough_unet_v6_simple.py +9 -20
breakthrough_unet_v6_simple.py CHANGED
@@ -104,28 +104,17 @@ class EnhancedMultiScaleFusion(nn.Module):
104
  feat_1x1 = self.mish_1x1(self.conv_1x1(x))
105
  feat_3x3 = self.mish_3x3(self.conv_3x3(x))
106
 
107
- # 🔬 RESEARCH: Conditional processing for CPU efficiency
108
- if self.training or H > 256: # Full processing for training or large images
109
- feat_5x5 = self.mish_5x5(self.conv_5x5(x))
110
- feat_dilated = self.mish_dilated(self.conv_dilated(x))
111
 
112
- # Global context with optimized interpolation
113
- global_feat = self.global_pool(x)
114
- global_conv = self.mish_global(self.global_conv(global_feat))
115
- global_upsampled = F.interpolate(global_conv, size=(H, W), mode='bilinear', align_corners=False)
116
-
117
- # Full feature concatenation
118
- all_features = torch.cat([feat_1x1, feat_3x3, feat_5x5, feat_dilated, global_upsampled], dim=1)
119
- else:
120
- # 🔬 RESEARCH: Simplified processing for CPU inference on smaller images
121
- feat_dilated = self.mish_dilated(self.conv_dilated(x))
122
-
123
- # Skip expensive 5x5 conv and global context for speed
124
- # Pad to maintain expected channel count
125
- feat_5x5_placeholder = torch.zeros_like(feat_1x1)
126
- global_placeholder = torch.zeros_like(feat_1x1)
127
 
128
- all_features = torch.cat([feat_1x1, feat_3x3, feat_5x5_placeholder, feat_dilated, global_placeholder], dim=1)
 
129
 
130
  # Fuse features
131
  fused = self.fusion(all_features)
 
104
  feat_1x1 = self.mish_1x1(self.conv_1x1(x))
105
  feat_3x3 = self.mish_3x3(self.conv_3x3(x))
106
 
107
+ # 🔬 RESEARCH: Always use full processing for stability (conditional optimization removed)
108
+ feat_5x5 = self.mish_5x5(self.conv_5x5(x))
109
+ feat_dilated = self.mish_dilated(self.conv_dilated(x))
 
110
 
111
+ # Global context
112
+ global_feat = self.global_pool(x)
113
+ global_conv = self.mish_global(self.global_conv(global_feat))
114
+ global_upsampled = F.interpolate(global_conv, size=(H, W), mode='bilinear', align_corners=False)
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ # Concatenate all features
117
+ all_features = torch.cat([feat_1x1, feat_3x3, feat_5x5, feat_dilated, global_upsampled], dim=1)
118
 
119
  # Fuse features
120
  fused = self.fusion(all_features)