ifieryarrows commited on
Commit
bb3710e
·
verified ·
1 Parent(s): 1f6d359

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. worker/tasks.py +17 -3
worker/tasks.py CHANGED
@@ -578,7 +578,21 @@ async def _execute_pipeline_stages_v2(
578
  from deep_learning.config import get_tft_config
579
  from pathlib import Path
580
 
581
- ckpt = Path(get_tft_config().training.best_model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
  if ckpt.exists():
583
  tft_report = generate_tft_analysis(session, "HG=F")
584
 
@@ -592,7 +606,7 @@ async def _execute_pipeline_stages_v2(
592
  logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
593
  else:
594
  result["tft_snapshot_generated"] = False
595
- logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found")
596
 
597
  except ImportError:
598
  result["tft_snapshot_generated"] = False
@@ -604,7 +618,7 @@ async def _execute_pipeline_stages_v2(
604
  # -------------------------------------------------------------------------
605
  # Stage 6: Generate commentary (only if snapshot was generated)
606
  # -------------------------------------------------------------------------
607
- if result.get("snapshot_generated") and snapshot_report:
608
  logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
609
  try:
610
  from app.commentary import generate_and_save_commentary
 
578
  from deep_learning.config import get_tft_config
579
  from pathlib import Path
580
 
581
+ tft_cfg = get_tft_config()
582
+ ckpt = Path(tft_cfg.training.best_model_path)
583
+
584
+ # If checkpoint is not cached locally, try to pull from HF Hub first
585
+ if not ckpt.exists():
586
+ try:
587
+ from deep_learning.models.hub import download_tft_artifacts
588
+ logger.info(f"[run_id={run_id}] TFT checkpoint not found locally – attempting HF Hub download")
589
+ download_tft_artifacts(
590
+ local_dir=ckpt.parent,
591
+ repo_id=tft_cfg.training.hf_model_repo,
592
+ )
593
+ except Exception as hub_exc:
594
+ logger.warning(f"[run_id={run_id}] HF Hub download failed: {hub_exc}")
595
+
596
  if ckpt.exists():
597
  tft_report = generate_tft_analysis(session, "HG=F")
598
 
 
606
  logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
607
  else:
608
  result["tft_snapshot_generated"] = False
609
+ logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found (train-tft workflow has not run yet)")
610
 
611
  except ImportError:
612
  result["tft_snapshot_generated"] = False
 
618
  # -------------------------------------------------------------------------
619
  # Stage 6: Generate commentary (only if snapshot was generated)
620
  # -------------------------------------------------------------------------
621
+ if (result.get("snapshot_generated") and snapshot_report) or result.get("tft_snapshot_generated"):
622
  logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
623
  try:
624
  from app.commentary import generate_and_save_commentary