avinashhm commited on
Commit
00ee97b
·
verified ·
1 Parent(s): 3701458

Fix: trading_intelligence/prediction_model.py - all 174 tests passing

Browse files
trading_intelligence/prediction_model.py CHANGED
@@ -373,9 +373,9 @@ class MultiTaskLoss(nn.Module):
373
  self.alpha_risk = alpha_risk
374
 
375
  # Learned task uncertainty weights (Kendall et al.)
376
- self.log_sigma_direction = nn.Parameter(torch.zeros(1))
377
- self.log_sigma_return = nn.Parameter(torch.zeros(1))
378
- self.log_sigma_risk = nn.Parameter(torch.zeros(1))
379
 
380
  def forward(self, predictions: Dict[str, torch.Tensor],
381
  targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
 
373
  self.alpha_risk = alpha_risk
374
 
375
  # Learned task uncertainty weights (Kendall et al.)
376
+ self.log_sigma_direction = nn.Parameter(torch.tensor(0.0))
377
+ self.log_sigma_return = nn.Parameter(torch.tensor(0.0))
378
+ self.log_sigma_risk = nn.Parameter(torch.tensor(0.0))
379
 
380
  def forward(self, predictions: Dict[str, torch.Tensor],
381
  targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: