ifieryarrows commited on
Commit
a1bedd7
·
verified ·
1 Parent(s): 622298f

Sync from GitHub (tests passed)

Browse files
deep_learning/models/losses.py CHANGED
@@ -219,7 +219,11 @@ class AdaptiveSharpeRatioLoss(nn.Module):
219
  # median for having lower variance than actual returns.
220
  # relu(1 - VR) fires when pred_std < actual_std; zero otherwise.
221
  median_std = median_pred.std() + self.sharpe_eps
222
- amplitude_loss = torch.relu(1.0 - median_std / actual_std)
 
 
 
 
223
 
224
  # --- Quantile (pinball) loss ---
225
  q_loss = self.quantile_loss(y_pred, y_actual)
 
219
  # median for having lower variance than actual returns.
220
  # relu(1 - VR) fires when pred_std < actual_std; zero otherwise.
221
  median_std = median_pred.std() + self.sharpe_eps
222
+ vr = median_std / actual_std
223
+ amplitude_loss = (
224
+ torch.relu(1.0 - vr) # under-variance: VR < 1 → strong penalty
225
+ + 0.25 * torch.relu(vr - 1.5) # over-variance: VR > 1.5 → gentle penalty
226
+ )
227
 
228
  # --- Quantile (pinball) loss ---
229
  q_loss = self.quantile_loss(y_pred, y_actual)
deep_learning/models/tft_copper.py CHANGED
@@ -82,7 +82,11 @@ try:
82
 
83
  # Median amplitude: penalise if median pred variance < actual variance
84
  median_std = median_pred.std() + self.sharpe_eps
85
- amplitude_loss = torch.relu(1.0 - median_std / actual_std)
 
 
 
 
86
 
87
  # Quantile (pinball) loss via parent — covers all 7 quantile bands
88
  q_loss = super().loss(y_pred, target)
 
82
 
83
  # Median amplitude: penalise if median pred variance < actual variance
84
  median_std = median_pred.std() + self.sharpe_eps
85
+ vr = median_std / actual_std
86
+ amplitude_loss = (
87
+ torch.relu(1.0 - vr) # under-variance: VR < 1 → strong penalty
88
+ + 0.25 * torch.relu(vr - 1.5) # over-variance: VR > 1.5 → gentle penalty
89
+ )
90
 
91
  # Quantile (pinball) loss via parent — covers all 7 quantile bands
92
  q_loss = super().loss(y_pred, target)