Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- worker/tasks.py +17 -3
worker/tasks.py
CHANGED
|
@@ -578,7 +578,21 @@ async def _execute_pipeline_stages_v2(
|
|
| 578 |
from deep_learning.config import get_tft_config
|
| 579 |
from pathlib import Path
|
| 580 |
|
| 581 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
if ckpt.exists():
|
| 583 |
tft_report = generate_tft_analysis(session, "HG=F")
|
| 584 |
|
|
@@ -592,7 +606,7 @@ async def _execute_pipeline_stages_v2(
|
|
| 592 |
logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
|
| 593 |
else:
|
| 594 |
result["tft_snapshot_generated"] = False
|
| 595 |
-
logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found")
|
| 596 |
|
| 597 |
except ImportError:
|
| 598 |
result["tft_snapshot_generated"] = False
|
|
@@ -604,7 +618,7 @@ async def _execute_pipeline_stages_v2(
|
|
| 604 |
# -------------------------------------------------------------------------
|
| 605 |
# Stage 6: Generate commentary (only if snapshot was generated)
|
| 606 |
# -------------------------------------------------------------------------
|
| 607 |
-
if result.get("snapshot_generated") and snapshot_report:
|
| 608 |
logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
|
| 609 |
try:
|
| 610 |
from app.commentary import generate_and_save_commentary
|
|
|
|
| 578 |
from deep_learning.config import get_tft_config
|
| 579 |
from pathlib import Path
|
| 580 |
|
| 581 |
+
tft_cfg = get_tft_config()
|
| 582 |
+
ckpt = Path(tft_cfg.training.best_model_path)
|
| 583 |
+
|
| 584 |
+
# If checkpoint is not cached locally, try to pull from HF Hub first
|
| 585 |
+
if not ckpt.exists():
|
| 586 |
+
try:
|
| 587 |
+
from deep_learning.models.hub import download_tft_artifacts
|
| 588 |
+
logger.info(f"[run_id={run_id}] TFT checkpoint not found locally – attempting HF Hub download")
|
| 589 |
+
download_tft_artifacts(
|
| 590 |
+
local_dir=ckpt.parent,
|
| 591 |
+
repo_id=tft_cfg.training.hf_model_repo,
|
| 592 |
+
)
|
| 593 |
+
except Exception as hub_exc:
|
| 594 |
+
logger.warning(f"[run_id={run_id}] HF Hub download failed: {hub_exc}")
|
| 595 |
+
|
| 596 |
if ckpt.exists():
|
| 597 |
tft_report = generate_tft_analysis(session, "HG=F")
|
| 598 |
|
|
|
|
| 606 |
logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
|
| 607 |
else:
|
| 608 |
result["tft_snapshot_generated"] = False
|
| 609 |
+
logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found (train-tft workflow has not run yet)")
|
| 610 |
|
| 611 |
except ImportError:
|
| 612 |
result["tft_snapshot_generated"] = False
|
|
|
|
| 618 |
# -------------------------------------------------------------------------
|
| 619 |
# Stage 6: Generate commentary (only if snapshot was generated)
|
| 620 |
# -------------------------------------------------------------------------
|
| 621 |
+
if (result.get("snapshot_generated") and snapshot_report) or result.get("tft_snapshot_generated"):
|
| 622 |
logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
|
| 623 |
try:
|
| 624 |
from app.commentary import generate_and_save_commentary
|