Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files
deep_learning/training/hyperopt.py
CHANGED
|
@@ -38,6 +38,60 @@ from deep_learning.config import (
|
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
| 42 |
"""Map an Optuna trial to a TFT-ASRO configuration."""
|
| 43 |
model_cfg = TFTModelConfig(
|
|
@@ -414,26 +468,28 @@ def run_hyperopt(
|
|
| 414 |
show_progress_bar=True,
|
| 415 |
)
|
| 416 |
|
| 417 |
-
best = study.best_trial
|
| 418 |
-
logger.info("Optuna best trial #%d: val_loss=%.6f", best.number, best.value)
|
| 419 |
-
logger.info("Best params: %s", best.params)
|
| 420 |
-
|
| 421 |
# Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
|
| 422 |
results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
|
| 423 |
results_path.parent.mkdir(parents=True, exist_ok=True)
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
"best_value": best.value,
|
| 427 |
-
"best_params": best.params,
|
| 428 |
-
"n_trials": len(study.trials),
|
| 429 |
-
}, indent=2))
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
|
| 439 |
# ---------------------------------------------------------------------------
|
|
@@ -453,7 +509,17 @@ if __name__ == "__main__":
|
|
| 453 |
print("\n" + "=" * 60)
|
| 454 |
print("HYPEROPT COMPLETE")
|
| 455 |
print("=" * 60)
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
logger = logging.getLogger(__name__)
|
| 39 |
|
| 40 |
|
| 41 |
+
def _trial_state_counts(study) -> dict[str, int]:
    """Tally the trials in ``study.trials`` by state, for logs and artifacts.

    Each trial is keyed by the ``name`` attribute of its ``state`` when that
    attribute exists, otherwise by ``str(state)``; the key is lowercased.
    Keys appear in the order in which each state is first encountered.

    Returns:
        A mapping from lowercase state label to the number of trials in it
        (empty when the study has no trials).
    """
    labels = [
        getattr(t.state, "name", str(t.state)).lower()
        for t in study.trials
    ]

    tally: dict[str, int] = {}
    for label in labels:
        tally.setdefault(label, 0)
        tally[label] += 1
    return tally
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _best_finite_completed_trial(study):
    """Select the best usable trial without raising when none exists.

    Optuna raises when no trial completed, so the best trial is picked here
    by hand. A trial is usable when its state's ``name`` is ``"COMPLETE"``
    and its ``value`` is neither ``None`` nor non-finite (NaN / +-inf).

    The ranking honours the study's optimisation direction: when
    ``study.direction`` has the name ``"MAXIMIZE"`` the largest value wins.
    In every other case (including a study object without a ``direction``
    attribute) the smallest value wins. Ties go to the earliest trial.

    Returns:
        The best usable trial, or ``None`` when there is no usable trial.
    """
    usable = [
        trial
        for trial in study.trials
        if getattr(trial.state, "name", None) == "COMPLETE"
        and trial.value is not None
        and np.isfinite(float(trial.value))
    ]
    if not usable:
        return None

    # Direction is looked up only when there is something to rank, so the
    # empty-result path reads nothing from the study beyond ``study.trials``.
    direction = getattr(study, "direction", None)
    pick = max if getattr(direction, "name", None) == "MAXIMIZE" else min
    return pick(usable, key=lambda trial: float(trial.value))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _build_result_payload(study) -> dict:
|
| 67 |
+
"""Build the persisted hyperopt artifact without assuming a best trial exists."""
|
| 68 |
+
trial_state_counts = _trial_state_counts(study)
|
| 69 |
+
best = _best_finite_completed_trial(study)
|
| 70 |
+
|
| 71 |
+
if best is None:
|
| 72 |
+
return {
|
| 73 |
+
"status": "no_finite_completed_trials",
|
| 74 |
+
"best_trial": None,
|
| 75 |
+
"best_value": None,
|
| 76 |
+
"best_params": {},
|
| 77 |
+
"n_trials": len(study.trials),
|
| 78 |
+
"trial_state_counts": trial_state_counts,
|
| 79 |
+
"message": (
|
| 80 |
+
"No Optuna trials completed with a finite objective value; "
|
| 81 |
+
"final training should use the default TFT-ASRO config."
|
| 82 |
+
),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
"status": "completed",
|
| 87 |
+
"best_trial": best.number,
|
| 88 |
+
"best_value": float(best.value),
|
| 89 |
+
"best_params": best.params,
|
| 90 |
+
"n_trials": len(study.trials),
|
| 91 |
+
"trial_state_counts": trial_state_counts,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
| 96 |
"""Map an Optuna trial to a TFT-ASRO configuration."""
|
| 97 |
model_cfg = TFTModelConfig(
|
|
|
|
| 468 |
show_progress_bar=True,
|
| 469 |
)
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
# Save alongside best_tft_asro.ckpt (tft/ root) so upload_tft_artifacts picks it up.
|
| 472 |
results_path = Path(base_cfg.training.best_model_path).parent / "optuna_results.json"
|
| 473 |
results_path.parent.mkdir(parents=True, exist_ok=True)
|
| 474 |
+
result = _build_result_payload(study)
|
| 475 |
+
results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
+
if result["best_trial"] is None:
|
| 478 |
+
logger.warning(
|
| 479 |
+
"Optuna finished without a finite completed trial; state counts=%s. "
|
| 480 |
+
"Wrote fallback artifact to %s",
|
| 481 |
+
result["trial_state_counts"],
|
| 482 |
+
results_path,
|
| 483 |
+
)
|
| 484 |
+
else:
|
| 485 |
+
logger.info(
|
| 486 |
+
"Optuna best trial #%d: val_loss=%.6f",
|
| 487 |
+
result["best_trial"],
|
| 488 |
+
result["best_value"],
|
| 489 |
+
)
|
| 490 |
+
logger.info("Best params: %s", result["best_params"])
|
| 491 |
+
|
| 492 |
+
return result
|
| 493 |
|
| 494 |
|
| 495 |
# ---------------------------------------------------------------------------
|
|
|
|
| 509 |
print("\n" + "=" * 60)
|
| 510 |
print("HYPEROPT COMPLETE")
|
| 511 |
print("=" * 60)
|
| 512 |
+
if result["best_trial"] is None:
|
| 513 |
+
print(f"Status: {result['status']}")
|
| 514 |
+
print(result["message"])
|
| 515 |
+
if result.get("trial_state_counts"):
|
| 516 |
+
counts = ", ".join(
|
| 517 |
+
f"{state}={count}"
|
| 518 |
+
for state, count in sorted(result["trial_state_counts"].items())
|
| 519 |
+
)
|
| 520 |
+
print(f"Trial states: {counts}")
|
| 521 |
+
else:
|
| 522 |
+
print(f"Best trial: #{result['best_trial']}")
|
| 523 |
+
print(f"Best val_loss: {result['best_value']:.6f}")
|
| 524 |
+
for k, v in result["best_params"].items():
|
| 525 |
+
print(f" {k}: {v}")
|