ifieryarrows commited on
Commit
3b43982
·
verified ·
1 Parent(s): c271c72

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. deep_learning/training/hyperopt.py +86 -20
deep_learning/training/hyperopt.py CHANGED
@@ -38,6 +38,60 @@ from deep_learning.config import (
38
  logger = logging.getLogger(__name__)
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
42
  """Map an Optuna trial to a TFT-ASRO configuration."""
43
  model_cfg = TFTModelConfig(
@@ -414,26 +468,28 @@ def run_hyperopt(
414
  show_progress_bar=True,
415
  )
416
 
417
- best = study.best_trial
418
- logger.info("Optuna best trial #%d: val_loss=%.6f", best.number, best.value)
419
- logger.info("Best params: %s", best.params)
420
-
421
  # Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
422
  results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
423
  results_path.parent.mkdir(parents=True, exist_ok=True)
424
- results_path.write_text(json.dumps({
425
- "best_trial": best.number,
426
- "best_value": best.value,
427
- "best_params": best.params,
428
- "n_trials": len(study.trials),
429
- }, indent=2))
430
 
431
- return {
432
- "best_trial": best.number,
433
- "best_value": best.value,
434
- "best_params": best.params,
435
- "n_trials": len(study.trials),
436
- }
 
 
 
 
 
 
 
 
 
 
437
 
438
 
439
  # ---------------------------------------------------------------------------
@@ -453,7 +509,17 @@ if __name__ == "__main__":
453
  print("\n" + "=" * 60)
454
  print("HYPEROPT COMPLETE")
455
  print("=" * 60)
456
- print(f"Best trial: #{result['best_trial']}")
457
- print(f"Best val_loss: {result['best_value']:.6f}")
458
- for k, v in result["best_params"].items():
459
- print(f" {k}: {v}")
 
 
 
 
 
 
 
 
 
 
 
38
  logger = logging.getLogger(__name__)
39
 
40
 
41
+ def _trial_state_counts(study) -> dict[str, int]:
42
+ """Return lowercase Optuna trial-state counts for logs and artifacts."""
43
+ counts: dict[str, int] = {}
44
+ for trial in study.trials:
45
+ state = getattr(trial.state, "name", str(trial.state)).lower()
46
+ counts[state] = counts.get(state, 0) + 1
47
+ return counts
48
+
49
+
50
+ def _best_finite_completed_trial(study):
51
+ """Optuna raises when no trial completed; select the usable best trial safely."""
52
+ completed = []
53
+ for trial in study.trials:
54
+ if getattr(trial.state, "name", None) != "COMPLETE":
55
+ continue
56
+ if trial.value is None or not np.isfinite(float(trial.value)):
57
+ continue
58
+ completed.append(trial)
59
+
60
+ if not completed:
61
+ return None
62
+
63
+ return min(completed, key=lambda trial: float(trial.value))
64
+
65
+
66
+ def _build_result_payload(study) -> dict:
67
+ """Build the persisted hyperopt artifact without assuming a best trial exists."""
68
+ trial_state_counts = _trial_state_counts(study)
69
+ best = _best_finite_completed_trial(study)
70
+
71
+ if best is None:
72
+ return {
73
+ "status": "no_finite_completed_trials",
74
+ "best_trial": None,
75
+ "best_value": None,
76
+ "best_params": {},
77
+ "n_trials": len(study.trials),
78
+ "trial_state_counts": trial_state_counts,
79
+ "message": (
80
+ "No Optuna trials completed with a finite objective value; "
81
+ "final training should use the default TFT-ASRO config."
82
+ ),
83
+ }
84
+
85
+ return {
86
+ "status": "completed",
87
+ "best_trial": best.number,
88
+ "best_value": float(best.value),
89
+ "best_params": best.params,
90
+ "n_trials": len(study.trials),
91
+ "trial_state_counts": trial_state_counts,
92
+ }
93
+
94
+
95
  def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
96
  """Map an Optuna trial to a TFT-ASRO configuration."""
97
  model_cfg = TFTModelConfig(
 
468
  show_progress_bar=True,
469
  )
470
 
 
 
 
 
471
  # Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
472
  results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
473
  results_path.parent.mkdir(parents=True, exist_ok=True)
474
+ result = _build_result_payload(study)
475
+ results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
 
 
 
 
476
 
477
+ if result["best_trial"] is None:
478
+ logger.warning(
479
+ "Optuna finished without a finite completed trial; state counts=%s. "
480
+ "Wrote fallback artifact to %s",
481
+ result["trial_state_counts"],
482
+ results_path,
483
+ )
484
+ else:
485
+ logger.info(
486
+ "Optuna best trial #%d: val_loss=%.6f",
487
+ result["best_trial"],
488
+ result["best_value"],
489
+ )
490
+ logger.info("Best params: %s", result["best_params"])
491
+
492
+ return result
493
 
494
 
495
  # ---------------------------------------------------------------------------
 
509
  print("\n" + "=" * 60)
510
  print("HYPEROPT COMPLETE")
511
  print("=" * 60)
512
+ if result["best_trial"] is None:
513
+ print(f"Status: {result['status']}")
514
+ print(result["message"])
515
+ if result.get("trial_state_counts"):
516
+ counts = ", ".join(
517
+ f"{state}={count}"
518
+ for state, count in sorted(result["trial_state_counts"].items())
519
+ )
520
+ print(f"Trial states: {counts}")
521
+ else:
522
+ print(f"Best trial: #{result['best_trial']}")
523
+ print(f"Best val_loss: {result['best_value']:.6f}")
524
+ for k, v in result["best_params"].items():
525
+ print(f" {k}: {v}")