griffingoodwin04 commited on
Commit
7d43b6a
·
1 Parent(s): 5eaed60

Update learning rate scheduler parameters and clean up loss calculation notes

Browse files
forecasting/models/vit_patch_model_local.py CHANGED
@@ -1,5 +1,4 @@
1
  from collections import deque
2
-
3
  import math
4
  import numpy as np
5
  import torch
@@ -51,7 +50,7 @@ class ViTLocal(pl.LightningModule):
51
 
52
  scheduler = CosineAnnealingWarmRestarts(
53
  optimizer,
54
- T_0=150,
55
  T_mult=2,
56
  eta_min=1e-7
57
  )
@@ -395,7 +394,6 @@ class SXRRegressionDynamicLoss:
395
 
396
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
397
  base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
398
- # base_loss = F.mse_loss(preds_norm, sxr_norm, reduction='none')
399
  weights = self._get_adaptive_weights(sxr_un)
400
  self._update_tracking(sxr_un, sxr_norm, preds_norm)
401
  weighted_loss = base_loss * weights
@@ -407,16 +405,16 @@ class SXRRegressionDynamicLoss:
407
 
408
  # Get continuous multipliers per class with custom params
409
  quiet_mult = self._get_performance_multiplier(
410
- self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet' # Was 0.2
411
  )
412
  c_mult = self._get_performance_multiplier(
413
- self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class' # Was 0.3
414
  )
415
  m_mult = self._get_performance_multiplier(
416
- self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class' # Was 0.4
417
  )
418
  x_mult = self._get_performance_multiplier(
419
- self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class' # Was 0.5
420
  )
421
 
422
  quiet_weight = self.base_weights['quiet'] * quiet_mult
 
1
  from collections import deque
 
2
  import math
3
  import numpy as np
4
  import torch
 
50
 
51
  scheduler = CosineAnnealingWarmRestarts(
52
  optimizer,
53
+ T_0=250,
54
  T_mult=2,
55
  eta_min=1e-7
56
  )
 
394
 
395
  def calculate_loss(self, preds_norm, sxr_norm, sxr_un):
396
  base_loss = F.huber_loss(preds_norm, sxr_norm, delta=.3, reduction='none')
 
397
  weights = self._get_adaptive_weights(sxr_un)
398
  self._update_tracking(sxr_un, sxr_norm, preds_norm)
399
  weighted_loss = base_loss * weights
 
405
 
406
  # Get continuous multipliers per class with custom params
407
  quiet_mult = self._get_performance_multiplier(
408
+ self.quiet_errors, max_multiplier=1.5, min_multiplier=0.6, sensitivity=0.05, sxrclass='quiet'
409
  )
410
  c_mult = self._get_performance_multiplier(
411
+ self.c_errors, max_multiplier=2, min_multiplier=0.7, sensitivity=0.08, sxrclass='c_class'
412
  )
413
  m_mult = self._get_performance_multiplier(
414
+ self.m_errors, max_multiplier=5.0, min_multiplier=0.8, sensitivity=0.1, sxrclass='m_class'
415
  )
416
  x_mult = self._get_performance_multiplier(
417
+ self.x_errors, max_multiplier=8.0, min_multiplier=0.8, sensitivity=0.12, sxrclass='x_class'
418
  )
419
 
420
  quiet_weight = self.base_weights['quiet'] * quiet_mult