ifieryarrows commited on
Commit
c06ed8a
·
verified ·
1 Parent(s): dff0b7c

Sync from GitHub (tests passed)

Browse files
deep_learning/models/hub.py CHANGED
@@ -20,6 +20,7 @@ _HF_TOKEN_ENV = "HF_TOKEN"
20
  _ARTIFACTS = [
21
  "best_tft_asro.ckpt",
22
  "pca_finbert.joblib",
 
23
  ]
24
 
25
 
 
20
  _ARTIFACTS = [
21
  "best_tft_asro.ckpt",
22
  "pca_finbert.joblib",
23
+ "optuna_results.json",
24
  ]
25
 
26
 
deep_learning/training/hyperopt.py CHANGED
@@ -204,7 +204,8 @@ def run_hyperopt(
204
  logger.info("Optuna best trial #%d: val_loss=%.6f", best.number, best.value)
205
  logger.info("Best params: %s", best.params)
206
 
207
- results_path = Path(base_cfg.training.checkpoint_dir) / "optuna_results.json"
 
208
  results_path.parent.mkdir(parents=True, exist_ok=True)
209
  results_path.write_text(json.dumps({
210
  "best_trial": best.number,
 
204
  logger.info("Optuna best trial #%d: val_loss=%.6f", best.number, best.value)
205
  logger.info("Best params: %s", best.params)
206
 
207
+ # Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
208
+ results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
209
  results_path.parent.mkdir(parents=True, exist_ok=True)
210
  results_path.write_text(json.dumps({
211
  "best_trial": best.number,
deep_learning/training/trainer.py CHANGED
@@ -286,7 +286,9 @@ def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
286
  from dataclasses import replace
287
  from deep_learning.config import ASROConfig, TFTModelConfig, TrainingConfig
288
 
289
- results_path = Path(cfg.training.checkpoint_dir) / "optuna_results.json"
 
 
290
  if not results_path.exists():
291
  return cfg
292
 
 
286
  from dataclasses import replace
287
  from deep_learning.config import ASROConfig, TFTModelConfig, TrainingConfig
288
 
289
+ # optuna_results.json is saved at tft/ root (alongside best_tft_asro.ckpt),
290
+ # not inside the checkpoints/ subdirectory.
291
+ results_path = Path(cfg.training.best_model_path).parent / "optuna_results.json"
292
  if not results_path.exists():
293
  return cfg
294