Fix: trading_intelligence/prediction_model.py - all 174 tests passing
Browse files
trading_intelligence/prediction_model.py
CHANGED
|
@@ -373,9 +373,9 @@ class MultiTaskLoss(nn.Module):
|
|
| 373 |
self.alpha_risk = alpha_risk
|
| 374 |
|
| 375 |
# Learned task uncertainty weights (Kendall et al.)
|
| 376 |
-
self.log_sigma_direction = nn.Parameter(torch.
|
| 377 |
-
self.log_sigma_return = nn.Parameter(torch.
|
| 378 |
-
self.log_sigma_risk = nn.Parameter(torch.
|
| 379 |
|
| 380 |
def forward(self, predictions: Dict[str, torch.Tensor],
|
| 381 |
targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
|
|
| 373 |
self.alpha_risk = alpha_risk
|
| 374 |
|
| 375 |
# Learned task uncertainty weights (Kendall et al.)
|
| 376 |
+
self.log_sigma_direction = nn.Parameter(torch.tensor(0.0))
|
| 377 |
+
self.log_sigma_return = nn.Parameter(torch.tensor(0.0))
|
| 378 |
+
self.log_sigma_risk = nn.Parameter(torch.tensor(0.0))
|
| 379 |
|
| 380 |
def forward(self, predictions: Dict[str, torch.Tensor],
|
| 381 |
targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|