Update doclayout_yolo/utils/callbacks/raytune.py
Browse files
doclayout_yolo/utils/callbacks/raytune.py
CHANGED
|
@@ -7,19 +7,17 @@ try:
|
|
| 7 |
import ray
|
| 8 |
from ray import tune
|
| 9 |
from ray.air import session
|
| 10 |
-
|
| 11 |
except (ImportError, AssertionError):
|
| 12 |
tune = None
|
| 13 |
-
|
| 14 |
|
| 15 |
def on_fit_epoch_end(trainer):
|
| 16 |
"""Sends training metrics to Ray Tune at end of each epoch."""
|
| 17 |
-
if
|
| 18 |
metrics = trainer.metrics
|
| 19 |
metrics["epoch"] = trainer.epoch
|
| 20 |
session.report(metrics)
|
| 21 |
|
| 22 |
-
|
| 23 |
callbacks = (
|
| 24 |
{
|
| 25 |
"on_fit_epoch_end": on_fit_epoch_end,
|
|
|
|
| 7 |
import ray
|
| 8 |
from ray import tune
|
| 9 |
from ray.air import session
|
|
|
|
| 10 |
except (ImportError, AssertionError):
|
| 11 |
tune = None
|
| 12 |
+
session = None
|
| 13 |
|
| 14 |
def on_fit_epoch_end(trainer):
|
| 15 |
"""Sends training metrics to Ray Tune at end of each epoch."""
|
| 16 |
+
if session and session.get_session(): # ✅ this is the correct Ray way now
|
| 17 |
metrics = trainer.metrics
|
| 18 |
metrics["epoch"] = trainer.epoch
|
| 19 |
session.report(metrics)
|
| 20 |
|
|
|
|
| 21 |
callbacks = (
|
| 22 |
{
|
| 23 |
"on_fit_epoch_end": on_fit_epoch_end,
|