Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- worker/tasks.py +28 -49
worker/tasks.py
CHANGED
|
@@ -119,11 +119,16 @@ async def run_pipeline(
|
|
| 119 |
Stage 1c: Cut-off calculation
|
| 120 |
Stage 1d: Price ingestion
|
| 121 |
Stage 2: Sentiment scoring
|
| 122 |
-
Stage 3: Sentiment aggregation
|
| 123 |
-
Stage
|
| 124 |
-
Stage
|
|
|
|
|
|
|
| 125 |
Stage 6: Commentary generation
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
Args:
|
| 128 |
ctx: arq context (contains redis connection)
|
| 129 |
run_id: Unique identifier for this run
|
|
@@ -502,39 +507,10 @@ async def _execute_pipeline_stages_v2(
|
|
| 502 |
result["model_trained"] = False
|
| 503 |
|
| 504 |
# -------------------------------------------------------------------------
|
| 505 |
-
# Stage 4.5: TFT-ASRO
|
|
|
|
| 506 |
# -------------------------------------------------------------------------
|
| 507 |
-
|
| 508 |
-
logger.info(f"[run_id={run_id}] Stage 4.5: TFT-ASRO training")
|
| 509 |
-
try:
|
| 510 |
-
from deep_learning.training.trainer import train_tft_model
|
| 511 |
-
|
| 512 |
-
tft_result = train_tft_model(use_asro=True)
|
| 513 |
-
|
| 514 |
-
result["tft_trained"] = True
|
| 515 |
-
result["tft_metrics"] = tft_result.get("test_metrics", {})
|
| 516 |
-
|
| 517 |
-
update_run_metrics(
|
| 518 |
-
session, run_id,
|
| 519 |
-
tft_trained=True,
|
| 520 |
-
tft_val_loss=tft_result.get("test_metrics", {}).get("mae"),
|
| 521 |
-
tft_sharpe=tft_result.get("test_metrics", {}).get("sharpe_ratio"),
|
| 522 |
-
tft_directional_accuracy=tft_result.get("test_metrics", {}).get("directional_accuracy"),
|
| 523 |
-
)
|
| 524 |
-
session.commit()
|
| 525 |
-
|
| 526 |
-
logger.info(f"[run_id={run_id}] TFT-ASRO training complete")
|
| 527 |
-
|
| 528 |
-
except ImportError:
|
| 529 |
-
logger.info(f"[run_id={run_id}] Stage 4.5 skipped: pytorch-forecasting not installed")
|
| 530 |
-
result["tft_trained"] = False
|
| 531 |
-
except Exception as e:
|
| 532 |
-
logger.warning(f"[run_id={run_id}] Stage 4.5 failed (non-critical): {e}")
|
| 533 |
-
result["tft_training_error"] = str(e)
|
| 534 |
-
result["tft_trained"] = False
|
| 535 |
-
session.rollback()
|
| 536 |
-
else:
|
| 537 |
-
result["tft_trained"] = False
|
| 538 |
|
| 539 |
# -------------------------------------------------------------------------
|
| 540 |
# Stage 5: Generate snapshot
|
|
@@ -616,31 +592,34 @@ async def _execute_pipeline_stages_v2(
|
|
| 616 |
session.rollback()
|
| 617 |
|
| 618 |
# -------------------------------------------------------------------------
|
| 619 |
-
# Stage 6: Generate commentary (
|
| 620 |
# -------------------------------------------------------------------------
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
| 622 |
logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
|
| 623 |
try:
|
| 624 |
from app.commentary import generate_and_save_commentary
|
| 625 |
-
|
| 626 |
-
|
| 627 |
await generate_and_save_commentary(
|
| 628 |
session=session,
|
| 629 |
symbol="HG=F",
|
| 630 |
-
current_price=
|
| 631 |
-
predicted_price=
|
| 632 |
-
predicted_return=
|
| 633 |
-
sentiment_index=
|
| 634 |
-
sentiment_label=
|
| 635 |
-
top_influencers=
|
| 636 |
-
news_count=
|
| 637 |
)
|
| 638 |
session.commit()
|
| 639 |
-
|
| 640 |
result["commentary_generated"] = True
|
| 641 |
update_run_metrics(session, run_id, commentary_generated=True)
|
| 642 |
session.commit()
|
| 643 |
-
|
| 644 |
except Exception as e:
|
| 645 |
logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
|
| 646 |
result["commentary_generated"] = False
|
|
|
|
| 119 |
Stage 1c: Cut-off calculation
|
| 120 |
Stage 1d: Price ingestion
|
| 121 |
Stage 2: Sentiment scoring
|
| 122 |
+
Stage 3: Sentiment aggregation
|
| 123 |
+
Stage 3.5: FinBERT embedding extraction
|
| 124 |
+
Stage 4: XGBoost training (optional)
|
| 125 |
+
Stage 5: XGBoost snapshot
|
| 126 |
+
Stage 5.5: TFT-ASRO inference (downloads checkpoint from HF Hub)
|
| 127 |
Stage 6: Commentary generation
|
| 128 |
+
|
| 129 |
+
Note: TFT-ASRO full training is handled exclusively by the weekly
|
| 130 |
+
tft-training.yml GitHub workflow. The daily pipeline only runs inference.
|
| 131 |
+
|
| 132 |
Args:
|
| 133 |
ctx: arq context (contains redis connection)
|
| 134 |
run_id: Unique identifier for this run
|
|
|
|
| 507 |
result["model_trained"] = False
|
| 508 |
|
| 509 |
# -------------------------------------------------------------------------
|
| 510 |
+
# Stage 4.5: TFT-ASRO — inference only (training handled by weekly
|
| 511 |
+
# tft-training.yml workflow; daily pipeline never retrains TFT)
|
| 512 |
# -------------------------------------------------------------------------
|
| 513 |
+
result["tft_trained"] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
# -------------------------------------------------------------------------
|
| 516 |
# Stage 5: Generate snapshot
|
|
|
|
| 592 |
session.rollback()
|
| 593 |
|
| 594 |
# -------------------------------------------------------------------------
|
| 595 |
+
# Stage 6: Generate commentary (if any snapshot was generated)
|
| 596 |
# -------------------------------------------------------------------------
|
| 597 |
+
has_xgb_snapshot = result.get("snapshot_generated") and snapshot_report
|
| 598 |
+
has_tft_snapshot = result.get("tft_snapshot_generated")
|
| 599 |
+
|
| 600 |
+
if has_xgb_snapshot or has_tft_snapshot:
|
| 601 |
logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
|
| 602 |
try:
|
| 603 |
from app.commentary import generate_and_save_commentary
|
| 604 |
+
|
| 605 |
+
report = snapshot_report or {}
|
| 606 |
await generate_and_save_commentary(
|
| 607 |
session=session,
|
| 608 |
symbol="HG=F",
|
| 609 |
+
current_price=report.get("current_price", 0.0),
|
| 610 |
+
predicted_price=report.get("predicted_price", 0.0),
|
| 611 |
+
predicted_return=report.get("predicted_return", 0.0),
|
| 612 |
+
sentiment_index=report.get("sentiment_index", 0.0),
|
| 613 |
+
sentiment_label=report.get("sentiment_label", "Neutral"),
|
| 614 |
+
top_influencers=report.get("top_influencers", []),
|
| 615 |
+
news_count=report.get("data_quality", {}).get("news_count_7d", 0),
|
| 616 |
)
|
| 617 |
session.commit()
|
| 618 |
+
|
| 619 |
result["commentary_generated"] = True
|
| 620 |
update_run_metrics(session, run_id, commentary_generated=True)
|
| 621 |
session.commit()
|
| 622 |
+
|
| 623 |
except Exception as e:
|
| 624 |
logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
|
| 625 |
result["commentary_generated"] = False
|