ifieryarrows commited on
Commit
c23f275
·
verified ·
1 Parent(s): af722fe

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. worker/tasks.py +28 -49
worker/tasks.py CHANGED
@@ -119,11 +119,16 @@ async def run_pipeline(
119
  Stage 1c: Cut-off calculation
120
  Stage 1d: Price ingestion
121
  Stage 2: Sentiment scoring
122
- Stage 3: Sentiment aggregation
123
- Stage 4: Model training (optional)
124
- Stage 5: Snapshot generation
 
 
125
  Stage 6: Commentary generation
126
-
 
 
 
127
  Args:
128
  ctx: arq context (contains redis connection)
129
  run_id: Unique identifier for this run
@@ -502,39 +507,10 @@ async def _execute_pipeline_stages_v2(
502
  result["model_trained"] = False
503
 
504
  # -------------------------------------------------------------------------
505
- # Stage 4.5: TFT-ASRO training (optional, parallel to XGBoost)
 
506
  # -------------------------------------------------------------------------
507
- if train_model:
508
- logger.info(f"[run_id={run_id}] Stage 4.5: TFT-ASRO training")
509
- try:
510
- from deep_learning.training.trainer import train_tft_model
511
-
512
- tft_result = train_tft_model(use_asro=True)
513
-
514
- result["tft_trained"] = True
515
- result["tft_metrics"] = tft_result.get("test_metrics", {})
516
-
517
- update_run_metrics(
518
- session, run_id,
519
- tft_trained=True,
520
- tft_val_loss=tft_result.get("test_metrics", {}).get("mae"),
521
- tft_sharpe=tft_result.get("test_metrics", {}).get("sharpe_ratio"),
522
- tft_directional_accuracy=tft_result.get("test_metrics", {}).get("directional_accuracy"),
523
- )
524
- session.commit()
525
-
526
- logger.info(f"[run_id={run_id}] TFT-ASRO training complete")
527
-
528
- except ImportError:
529
- logger.info(f"[run_id={run_id}] Stage 4.5 skipped: pytorch-forecasting not installed")
530
- result["tft_trained"] = False
531
- except Exception as e:
532
- logger.warning(f"[run_id={run_id}] Stage 4.5 failed (non-critical): {e}")
533
- result["tft_training_error"] = str(e)
534
- result["tft_trained"] = False
535
- session.rollback()
536
- else:
537
- result["tft_trained"] = False
538
 
539
  # -------------------------------------------------------------------------
540
  # Stage 5: Generate snapshot
@@ -616,31 +592,34 @@ async def _execute_pipeline_stages_v2(
616
  session.rollback()
617
 
618
  # -------------------------------------------------------------------------
619
- # Stage 6: Generate commentary (only if snapshot was generated)
620
  # -------------------------------------------------------------------------
621
- if (result.get("snapshot_generated") and snapshot_report) or result.get("tft_snapshot_generated"):
 
 
 
622
  logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
623
  try:
624
  from app.commentary import generate_and_save_commentary
625
-
626
- # Extract required fields from snapshot and await async call
627
  await generate_and_save_commentary(
628
  session=session,
629
  symbol="HG=F",
630
- current_price=snapshot_report.get("current_price", 0.0),
631
- predicted_price=snapshot_report.get("predicted_price", 0.0),
632
- predicted_return=snapshot_report.get("predicted_return", 0.0),
633
- sentiment_index=snapshot_report.get("sentiment_index", 0.0),
634
- sentiment_label=snapshot_report.get("sentiment_label", "Neutral"),
635
- top_influencers=snapshot_report.get("top_influencers", []),
636
- news_count=snapshot_report.get("data_quality", {}).get("news_count_7d", 0),
637
  )
638
  session.commit()
639
-
640
  result["commentary_generated"] = True
641
  update_run_metrics(session, run_id, commentary_generated=True)
642
  session.commit()
643
-
644
  except Exception as e:
645
  logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
646
  result["commentary_generated"] = False
 
119
  Stage 1c: Cut-off calculation
120
  Stage 1d: Price ingestion
121
  Stage 2: Sentiment scoring
122
+ Stage 3: Sentiment aggregation
123
+ Stage 3.5: FinBERT embedding extraction
124
+ Stage 4: XGBoost training (optional)
125
+ Stage 5: XGBoost snapshot
126
+ Stage 5.5: TFT-ASRO inference (downloads checkpoint from HF Hub)
127
  Stage 6: Commentary generation
128
+
129
+ Note: TFT-ASRO full training is handled exclusively by the weekly
130
+ tft-training.yml GitHub workflow. The daily pipeline only runs inference.
131
+
132
  Args:
133
  ctx: arq context (contains redis connection)
134
  run_id: Unique identifier for this run
 
507
  result["model_trained"] = False
508
 
509
  # -------------------------------------------------------------------------
510
+ # Stage 4.5: TFT-ASRO inference only (training handled by weekly
511
+ # tft-training.yml workflow; daily pipeline never retrains TFT)
512
  # -------------------------------------------------------------------------
513
+ result["tft_trained"] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  # -------------------------------------------------------------------------
516
  # Stage 5: Generate snapshot
 
592
  session.rollback()
593
 
594
  # -------------------------------------------------------------------------
595
+ # Stage 6: Generate commentary (if any snapshot was generated)
596
  # -------------------------------------------------------------------------
597
+ has_xgb_snapshot = result.get("snapshot_generated") and snapshot_report
598
+ has_tft_snapshot = result.get("tft_snapshot_generated")
599
+
600
+ if has_xgb_snapshot or has_tft_snapshot:
601
  logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
602
  try:
603
  from app.commentary import generate_and_save_commentary
604
+
605
+ report = snapshot_report or {}
606
  await generate_and_save_commentary(
607
  session=session,
608
  symbol="HG=F",
609
+ current_price=report.get("current_price", 0.0),
610
+ predicted_price=report.get("predicted_price", 0.0),
611
+ predicted_return=report.get("predicted_return", 0.0),
612
+ sentiment_index=report.get("sentiment_index", 0.0),
613
+ sentiment_label=report.get("sentiment_label", "Neutral"),
614
+ top_influencers=report.get("top_influencers", []),
615
+ news_count=report.get("data_quality", {}).get("news_count_7d", 0),
616
  )
617
  session.commit()
618
+
619
  result["commentary_generated"] = True
620
  update_run_metrics(session, run_id, commentary_generated=True)
621
  session.commit()
622
+
623
  except Exception as e:
624
  logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
625
  result["commentary_generated"] = False