""" Worker tasks for arq. This module defines the tasks that the worker executes. The main task is `run_pipeline` which orchestrates the entire pipeline. Faz 2: Integrated news_raw/news_processed pipeline with proper commit boundaries, metrics tracking, and degraded mode handling. """ import logging import os import socket import uuid from datetime import datetime, timezone from typing import Any, Optional from sqlalchemy.orm import Session # These imports will be updated as we refactor import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.db import SessionLocal, init_db, get_db_type from app.settings import get_settings from app.models import PipelineRunMetrics from adapters.db.lock import ( PIPELINE_LOCK_KEY, try_acquire_lock, release_lock, write_lock_visibility, clear_lock_visibility, ) logger = logging.getLogger(__name__) # ============================================================================= # Helper functions for metrics tracking # ============================================================================= def create_run_metrics( session: Session, run_id: str, started_at: datetime, ) -> PipelineRunMetrics: """Create initial pipeline_run_metrics record.""" metrics = PipelineRunMetrics( run_id=run_id, run_started_at=started_at, status="running", ) session.add(metrics) session.flush() return metrics def update_run_metrics( session: Session, run_id: str, **kwargs, ) -> None: """Update pipeline_run_metrics with new values.""" metrics = session.query(PipelineRunMetrics).filter( PipelineRunMetrics.run_id == run_id ).first() if metrics: for key, value in kwargs.items(): if hasattr(metrics, key): setattr(metrics, key, value) session.flush() def finalize_run_metrics( session: Session, run_id: str, status: str, quality_state: str = "ok", error_message: Optional[str] = None, ) -> None: """Finalize run metrics with completion status.""" completed_at = datetime.now(timezone.utc) metrics = session.query(PipelineRunMetrics).filter( PipelineRunMetrics.run_id == run_id ).first() if metrics: metrics.run_completed_at = completed_at metrics.status = status metrics.quality_state = quality_state if metrics.run_started_at: metrics.duration_seconds = (completed_at - metrics.run_started_at).total_seconds() if error_message: metrics.error_message = error_message session.flush() # ============================================================================= # Main pipeline task # ============================================================================= async def run_pipeline( ctx: dict, run_id: str, train_model: bool = False, trigger_source: str = "unknown", enqueued_at: str = None, ) -> dict: """ Main pipeline task - executed by arq worker. This is the ONLY entrypoint for pipeline execution. Faz 2 Flow: Stage 1a: News ingestion → news_raw Stage 1b: Raw processing → news_processed Stage 1c: Cut-off calculation Stage 1d: Price ingestion Stage 2: Sentiment scoring Stage 3: Sentiment aggregation Stage 3.5: FinBERT embedding extraction Stage 4: XGBoost training (optional) Stage 5: XGBoost snapshot Stage 5.5: TFT-ASRO inference (downloads checkpoint from HF Hub) Stage 6: Commentary generation Note: TFT-ASRO full training is handled exclusively by the weekly tft-training.yml GitHub workflow. The daily pipeline only runs inference. 

    Args:
        ctx: arq context (contains redis connection)
        run_id: Unique identifier for this run
        train_model: Whether to train the XGBoost model
        trigger_source: Where the trigger came from (cron, manual, api)
        enqueued_at: ISO timestamp when job was enqueued

    Returns:
        dict with run results
    """
    started_at = datetime.now(timezone.utc)
    holder_id = f"{socket.gethostname()}:{os.getpid()}"
    run_uuid = uuid.UUID(run_id) if isinstance(run_id, str) else run_id

    logger.info(f"[run_id={run_id}] Pipeline starting: trigger={trigger_source}, train_model={train_model}")

    # Initialize database
    init_db()

    # Get a dedicated session for this pipeline run
    # IMPORTANT: This session holds the advisory lock
    session: Session = SessionLocal()

    quality_state = "ok"
    result = {}

    try:
        # 0. Create run metrics record
        create_run_metrics(session, run_id, started_at)
        session.commit()

        # 1. Acquire distributed lock
        if not try_acquire_lock(session, PIPELINE_LOCK_KEY):
            logger.warning(f"[run_id={run_id}] Pipeline skipped: lock held by another process")
            finalize_run_metrics(session, run_id, status="skipped_locked", quality_state="skipped")
            session.commit()
            return {
                "run_id": run_id,
                "status": "skipped_locked",
                "message": "Another pipeline is running",
            }

        # Write lock visibility (best-effort)
        write_lock_visibility(session, PIPELINE_LOCK_KEY, run_id, holder_id)
        session.commit()

        logger.info(f"[run_id={run_id}] Lock acquired, executing pipeline...")

        # 2. Execute pipeline stages with proper commit boundaries
        result = await _execute_pipeline_stages_v2(
            session=session,
            run_id=run_id,
            run_uuid=run_uuid,
            train_model=train_model,
        )

        # Determine quality state from result
        # More nuanced logic to avoid false alarms
        raw_inserted = result.get("news_raw_inserted", 0)
        proc_inserted = result.get("news_processed_inserted", 0)
        raw_error = result.get("news_raw_error")
        proc_error = result.get("news_processed_error")

        if raw_error or proc_error:
            # Actual errors during ingestion/processing
            quality_state = "degraded"
            result["message"] = f"Pipeline errors: {raw_error or ''} {proc_error or ''}".strip()
        elif raw_inserted == 0 and proc_inserted == 0:
            # No new data at all - could be dedup working or sources haven't updated
            quality_state = "stale"
            result["message"] = "No new articles - sources may not have updated"
        elif raw_inserted > 0 and proc_inserted == 0:
            # Got raw but nothing processed - potential dedup anomaly
            quality_state = "ok"  # This is actually fine - all duplicates
            result["message"] = f"All {raw_inserted} articles were duplicates"
        else:
            quality_state = "ok"
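
        # Quality-state summary for this run:
        #   ok       -> fresh data ingested, or every new raw article was a duplicate
        #   stale    -> no new raw or processed articles at all
        #   degraded -> an ingestion/processing stage reported an error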

        # 3. Record success
        finished_at = datetime.now(timezone.utc)
        duration = (finished_at - started_at).total_seconds()

        finalize_run_metrics(
            session,
            run_id,
            status="success",
            quality_state=quality_state,
        )
        session.commit()

        logger.info(f"[run_id={run_id}] Pipeline completed in {duration:.1f}s")

        return {
            "run_id": run_id,
            "status": "success",
            "quality_state": quality_state,
            "started_at": started_at.isoformat(),
            "finished_at": finished_at.isoformat(),
            "duration_seconds": duration,
            "train_model": train_model,
            **result,
        }

    except Exception as e:
        logger.error(f"[run_id={run_id}] Pipeline failed: {e}", exc_info=True)
        try:
            finalize_run_metrics(
                session,
                run_id,
                status="failed",
                quality_state="failed",
                error_message=str(e)[:1000],
            )
            session.commit()
        except Exception:
            session.rollback()
        return {
            "run_id": run_id,
            "status": "failed",
            "error": str(e),
        }

    finally:
        # Always release lock and cleanup
        try:
            release_lock(session, PIPELINE_LOCK_KEY)
            clear_lock_visibility(session, PIPELINE_LOCK_KEY)
            session.commit()
        except Exception:
            session.rollback()
        finally:
            session.close()


async def _execute_pipeline_stages_v2(
    session: Session,
    run_id: str,
    run_uuid: uuid.UUID,
    train_model: bool,
) -> dict:
    """
    Execute pipeline stages with Faz 2 news pipeline integration.

    Each stage has proper commit boundaries and metrics updates.
    """
    from app.settings import get_settings

    settings = get_settings()
    result = {}
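
    # Every stage below follows the same pattern: commit on success; on failure,
    # record the error in `result`, roll back, and continue with the next stage.
    # A failing stage therefore degrades the run instead of aborting it.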

    # -------------------------------------------------------------------------
    # Stage 1a: News ingestion → news_raw (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1a: News ingestion → news_raw")
    try:
        from pipelines.ingestion.news import ingest_news_to_raw

        raw_stats = ingest_news_to_raw(
            session=session,
            run_id=run_uuid,
        )
        session.commit()

        result["news_raw_inserted"] = raw_stats.get("inserted", 0)
        result["news_raw_duplicates"] = raw_stats.get("duplicates", 0)

        update_run_metrics(
            session,
            run_id,
            news_raw_inserted=raw_stats.get("inserted", 0),
            news_raw_duplicates=raw_stats.get("duplicates", 0),
        )
        session.commit()

        logger.info(f"[run_id={run_id}] news_raw: {raw_stats.get('inserted', 0)} inserted")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1a failed: {e}")
        result["news_raw_error"] = str(e)
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 1b: Raw → Processed (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1b: news_raw → news_processed")
    try:
        from pipelines.processing.news import process_raw_to_processed

        proc_stats = process_raw_to_processed(
            session=session,
            run_id=run_uuid,
            batch_size=200,
        )
        session.commit()

        result["news_processed_inserted"] = proc_stats.get("inserted", 0)
        result["news_processed_duplicates"] = proc_stats.get("duplicates", 0)

        update_run_metrics(
            session,
            run_id,
            news_processed_inserted=proc_stats.get("inserted", 0),
            news_processed_duplicates=proc_stats.get("duplicates", 0),
        )
        session.commit()

        logger.info(f"[run_id={run_id}] news_processed: {proc_stats.get('inserted', 0)} inserted")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1b failed: {e}")
        result["news_processed_error"] = str(e)
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 1c: Cut-off calculation (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1c: Computing news cut-off")
    try:
        from pipelines.cutoff import compute_news_cutoff

        cutoff_dt = compute_news_cutoff(
            run_datetime=datetime.now(timezone.utc),
            market_tz=settings.market_timezone,
            market_close=settings.market_close_time,
            buffer_minutes=settings.cutoff_buffer_minutes,
        )
        result["news_cutoff_time"] = cutoff_dt.isoformat()

        update_run_metrics(session, run_id, news_cutoff_time=cutoff_dt)
        session.commit()

        logger.info(f"[run_id={run_id}] Cut-off: {cutoff_dt.isoformat()}")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1c failed: {e}")
        result["cutoff_error"] = str(e)
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 1d: Price ingestion (existing)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1d: Price ingestion")
    try:
        from app.data_manager import ingest_prices

        # ingest_prices returns per-symbol stats keyed by ticker, e.g.
        # {"HG=F": {"imported": <n>, ...}} (shape inferred from the usage below).
        price_stats = ingest_prices(session)
        session.commit()

        result["symbols_fetched"] = len(price_stats)
        result["price_bars_updated"] = sum(
            s.get("imported", 0) for s in price_stats.values()
        )

        update_run_metrics(
            session,
            run_id,
            price_bars_updated=result["price_bars_updated"],
        )
        session.commit()
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1d failed: {e}")
        result["price_error"] = str(e)
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 2: Sentiment scoring (V2 - news_processed based)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 2: Sentiment scoring")
    try:
        from app.ai_engine import score_unscored_processed_articles

        scoring_stats = score_unscored_processed_articles(session)
        session.commit()

        result["articles_scored"] = int(scoring_stats.get("scored_count", 0))
        result["articles_scored_v2"] = int(scoring_stats.get("scored_count", 0))
        result["llm_parse_fail_count"] = int(scoring_stats.get("parse_fail_count", 0))
        result["escalation_count"] = int(scoring_stats.get("escalation_count", 0))
        result["fallback_count"] = int(scoring_stats.get("fallback_count", 0))

        update_run_metrics(
            session,
            run_id,
            articles_scored_v2=result["articles_scored_v2"],
            llm_parse_fail_count=result["llm_parse_fail_count"],
            escalation_count=result["escalation_count"],
            fallback_count=result["fallback_count"],
        )
        session.commit()
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 2 failed: {e}")
        result["scoring_error"] = str(e)
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 3: Sentiment aggregation (existing)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 3: Sentiment aggregation")
    try:
        from app.ai_engine import aggregate_daily_sentiment_v2

        days_aggregated_v2 = aggregate_daily_sentiment_v2(session)
        session.commit()
        result["days_aggregated_v2"] = days_aggregated_v2
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 3 failed: {e}")
        result["aggregation_error"] = str(e)
        session.rollback()
result["tft_pca_fitted"] = emb_stats.get("pca_fitted", False) update_run_metrics( session, run_id, tft_embeddings_computed=emb_stats.get("embedded", 0), ) session.commit() logger.info( f"[run_id={run_id}] FinBERT embeddings: " f"{emb_stats.get('embedded', 0)} computed, {emb_stats.get('skipped', 0)} skipped" ) except ImportError: logger.info(f"[run_id={run_id}] Stage 3.5 skipped: deep_learning module not available") except Exception as e: logger.warning(f"[run_id={run_id}] Stage 3.5 failed (non-critical): {e}") result["tft_embedding_error"] = str(e) session.rollback() # ------------------------------------------------------------------------- # Stage 4: Model training (optional) # ------------------------------------------------------------------------- if train_model: logger.info(f"[run_id={run_id}] Stage 4: Model training") try: from app.ai_engine import train_xgboost_model, save_model_metadata_to_db train_result = train_xgboost_model(session) save_model_metadata_to_db( session, symbol="HG=F", importance=train_result.get("importance", []), features=train_result.get("features", []), metrics=train_result.get("metrics", {}), ) session.commit() result["model_trained"] = True result["model_metrics"] = train_result.get("metrics", {}) update_run_metrics( session, run_id, train_mae=train_result.get("metrics", {}).get("mae"), val_mae=train_result.get("metrics", {}).get("val_mae"), ) session.commit() except Exception as e: logger.error(f"[run_id={run_id}] Stage 4 failed: {e}") result["training_error"] = str(e) result["model_trained"] = False session.rollback() else: result["model_trained"] = False # ------------------------------------------------------------------------- # Stage 4.5: TFT-ASRO — inference only (training handled by weekly # tft-training.yml workflow; daily pipeline never retrains TFT) # ------------------------------------------------------------------------- result["tft_trained"] = False # ------------------------------------------------------------------------- # Stage 5: Generate snapshot # ------------------------------------------------------------------------- logger.info(f"[run_id={run_id}] Stage 5: Generate snapshot") snapshot_report = None # Will be used by Stage 6 try: from app.inference import generate_analysis_report, save_analysis_snapshot report = generate_analysis_report(session, "HG=F") if report: # Add Faz 2 metadata report["quality_state"] = "ok" if result.get("news_processed_inserted", 0) == 0: report["quality_state"] = "degraded" report["message"] = "No fresh news data" save_analysis_snapshot(session, report, "HG=F") session.commit() result["snapshot_generated"] = True snapshot_report = report # Save for Stage 6 update_run_metrics(session, run_id, snapshot_generated=True) session.commit() else: result["snapshot_generated"] = False except Exception as e: logger.error(f"[run_id={run_id}] Stage 5 failed: {e}") result["snapshot_error"] = str(e) result["snapshot_generated"] = False session.rollback() # ------------------------------------------------------------------------- # Stage 5.5: TFT-ASRO snapshot (parallel to XGBoost snapshot) # ------------------------------------------------------------------------- logger.info(f"[run_id={run_id}] Stage 5.5: TFT-ASRO snapshot") try: from deep_learning.inference.predictor import generate_tft_analysis from deep_learning.config import get_tft_config from pathlib import Path tft_cfg = get_tft_config() ckpt = Path(tft_cfg.training.best_model_path) # If checkpoint is not cached locally, try to pull from HF Hub first if not 
        if not ckpt.exists():
            try:
                from deep_learning.models.hub import download_tft_artifacts

                logger.info(f"[run_id={run_id}] TFT checkpoint not found locally – attempting HF Hub download")
                download_tft_artifacts(
                    local_dir=ckpt.parent,
                    repo_id=tft_cfg.training.hf_model_repo,
                )
            except Exception as hub_exc:
                logger.warning(f"[run_id={run_id}] HF Hub download failed: {hub_exc}")

        if ckpt.exists():
            tft_report = generate_tft_analysis(session, "HG=F")
            if "error" not in tft_report:
                result["tft_snapshot_generated"] = True
                update_run_metrics(session, run_id, tft_snapshot_generated=True)
                session.commit()
                logger.info(f"[run_id={run_id}] TFT-ASRO snapshot generated")
            else:
                result["tft_snapshot_generated"] = False
                logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
        else:
            result["tft_snapshot_generated"] = False
            logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found (train-tft workflow has not run yet)")
    except ImportError:
        result["tft_snapshot_generated"] = False
    except Exception as e:
        logger.warning(f"[run_id={run_id}] Stage 5.5 failed (non-critical): {e}")
        result["tft_snapshot_generated"] = False
        session.rollback()

    # -------------------------------------------------------------------------
    # Stage 6: Generate commentary (if any snapshot was generated)
    # -------------------------------------------------------------------------
    has_xgb_snapshot = result.get("snapshot_generated") and snapshot_report
    has_tft_snapshot = result.get("tft_snapshot_generated")

    if has_xgb_snapshot or has_tft_snapshot:
        logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
        try:
            from app.commentary import generate_and_save_commentary

            report = snapshot_report or {}
            await generate_and_save_commentary(
                session=session,
                symbol="HG=F",
                current_price=report.get("current_price", 0.0),
                predicted_price=report.get("predicted_price", 0.0),
                predicted_return=report.get("predicted_return", 0.0),
                sentiment_index=report.get("sentiment_index", 0.0),
                sentiment_label=report.get("sentiment_label", "Neutral"),
                top_influencers=report.get("top_influencers", []),
                news_count=report.get("data_quality", {}).get("news_count_7d", 0),
            )
            session.commit()

            result["commentary_generated"] = True
            update_run_metrics(session, run_id, commentary_generated=True)
            session.commit()
        except Exception as e:
            logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
            result["commentary_generated"] = False
            session.rollback()
    else:
        logger.warning(f"[run_id={run_id}] Stage 6 skipped: no snapshot generated")
        result["commentary_generated"] = False

    return result


# =============================================================================
# arq worker lifecycle
# =============================================================================

async def startup(ctx: dict) -> None:
    """Called when the worker starts."""
    logger.info("Worker starting up...")
    init_db()


async def shutdown(ctx: dict) -> None:
    """Called when the worker shuts down."""
    logger.info("Worker shutting down...")
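

# A minimal WorkerSettings wiring for the hooks above could look like the
# sketch below (illustrative only; the actual worker configuration may live in
# a separate module, and the REDIS_URL env var name is an assumption):
#
#     from arq.connections import RedisSettings
#
#     class WorkerSettings:
#         functions = [run_pipeline]
#         on_startup = startup
#         on_shutdown = shutdown
#         redis_settings = RedisSettings.from_dsn(os.environ["REDIS_URL"])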