# copper-mind / worker / tasks.py
# Synced from GitHub (tests passed) — commit c23f275 (verified)
"""
Worker tasks for arq.
This module defines the tasks that the worker executes.
The main task is `run_pipeline` which orchestrates the entire pipeline.
Faz 2: Integrated news_raw/news_processed pipeline with proper
commit boundaries, metrics tracking, and degraded mode handling.
"""
import logging
import os
import socket
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy.orm import Session
# These imports will be updated as we refactor
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.db import SessionLocal, init_db, get_db_type
from app.settings import get_settings
from app.models import PipelineRunMetrics
from adapters.db.lock import (
PIPELINE_LOCK_KEY,
try_acquire_lock,
release_lock,
write_lock_visibility,
clear_lock_visibility,
)
logger = logging.getLogger(__name__)
# =============================================================================
# Helper functions for metrics tracking
# =============================================================================
def create_run_metrics(
    session: Session,
    run_id: str,
    started_at: datetime,
) -> PipelineRunMetrics:
    """Insert a fresh pipeline_run_metrics row in the ``running`` state.

    The row is flushed (not committed) so the caller keeps control of the
    transaction boundary.

    Args:
        session: Active SQLAlchemy session.
        run_id: Unique identifier of the pipeline run.
        started_at: UTC timestamp at which the run began.

    Returns:
        The newly created, session-attached PipelineRunMetrics instance.
    """
    record = PipelineRunMetrics(
        run_id=run_id,
        run_started_at=started_at,
        status="running",
    )
    session.add(record)
    session.flush()
    return record
def update_run_metrics(
    session: Session,
    run_id: str,
    **kwargs,
) -> None:
    """Apply the given field updates to the metrics row for *run_id*.

    Keyword names that are not attributes of the model are silently
    ignored; a missing row is a no-op.  Changes are flushed, not
    committed — the caller owns the transaction.
    """
    row = (
        session.query(PipelineRunMetrics)
        .filter(PipelineRunMetrics.run_id == run_id)
        .first()
    )
    if not row:
        return
    for field, value in kwargs.items():
        # Guard with hasattr so stray keys never raise.
        if hasattr(row, field):
            setattr(row, field, value)
    session.flush()
def finalize_run_metrics(
    session: Session,
    run_id: str,
    status: str,
    quality_state: str = "ok",
    error_message: Optional[str] = None,
) -> None:
    """Stamp completion fields on the metrics row for *run_id*.

    Records the completion time, final status/quality state, wall-clock
    duration (only when a start time exists) and an optional truncated
    error message.  A missing row is a no-op; changes are flushed, not
    committed.
    """
    now = datetime.now(timezone.utc)
    row = (
        session.query(PipelineRunMetrics)
        .filter(PipelineRunMetrics.run_id == run_id)
        .first()
    )
    if not row:
        return
    row.run_completed_at = now
    row.status = status
    row.quality_state = quality_state
    if row.run_started_at:
        row.duration_seconds = (now - row.run_started_at).total_seconds()
    if error_message:
        row.error_message = error_message
    session.flush()
# =============================================================================
# Main pipeline task
# =============================================================================
async def run_pipeline(
    ctx: dict,
    run_id: str,
    train_model: bool = False,
    trigger_source: str = "unknown",
    enqueued_at: Optional[str] = None,
) -> dict:
    """
    Main pipeline task - executed by arq worker.
    This is the ONLY entrypoint for pipeline execution.
    Faz 2 Flow:
    Stage 1a: News ingestion → news_raw
    Stage 1b: Raw processing → news_processed
    Stage 1c: Cut-off calculation
    Stage 1d: Price ingestion
    Stage 2: Sentiment scoring
    Stage 3: Sentiment aggregation
    Stage 3.5: FinBERT embedding extraction
    Stage 4: XGBoost training (optional)
    Stage 5: XGBoost snapshot
    Stage 5.5: TFT-ASRO inference (downloads checkpoint from HF Hub)
    Stage 6: Commentary generation
    Note: TFT-ASRO full training is handled exclusively by the weekly
    tft-training.yml GitHub workflow. The daily pipeline only runs inference.
    Args:
        ctx: arq context (contains redis connection)
        run_id: Unique identifier for this run
        train_model: Whether to train the XGBoost model
        trigger_source: Where the trigger came from (cron, manual, api)
        enqueued_at: ISO timestamp when job was enqueued (currently unused
            in the body; kept for call-site compatibility)
    Returns:
        dict with run results: always contains "run_id" and "status"
        ("success" | "failed" | "skipped_locked"); on success also the
        quality state, timings, and per-stage counters from the stage
        executor.
    """
    started_at = datetime.now(timezone.utc)
    # hostname:pid identifies which worker process holds the lock.
    holder_id = f"{socket.gethostname()}:{os.getpid()}"
    # Normalise run_id: stages expect a uuid.UUID, metrics use the string form.
    run_uuid = uuid.UUID(run_id) if isinstance(run_id, str) else run_id
    logger.info(f"[run_id={run_id}] Pipeline starting: trigger={trigger_source}, train_model={train_model}")
    # Initialize database
    init_db()
    # Get a dedicated session for this pipeline run
    # IMPORTANT: This session holds the advisory lock, so it must stay
    # open for the full run and be the one that releases the lock.
    session: Session = SessionLocal()
    quality_state = "ok"
    result = {}
    try:
        # 0. Create run metrics record (committed immediately so the run is
        # visible even if lock acquisition fails below).
        create_run_metrics(session, run_id, started_at)
        session.commit()
        # 1. Acquire distributed lock — ensures at most one pipeline runs
        # at a time across all workers.
        if not try_acquire_lock(session, PIPELINE_LOCK_KEY):
            logger.warning(f"[run_id={run_id}] Pipeline skipped: lock held by another process")
            finalize_run_metrics(session, run_id, status="skipped_locked", quality_state="skipped")
            session.commit()
            return {
                "run_id": run_id,
                "status": "skipped_locked",
                "message": "Another pipeline is running",
            }
        # Write lock visibility (best-effort)
        write_lock_visibility(session, PIPELINE_LOCK_KEY, run_id, holder_id)
        session.commit()
        logger.info(f"[run_id={run_id}] Lock acquired, executing pipeline...")
        # 2. Execute pipeline stages with proper commit boundaries
        result = await _execute_pipeline_stages_v2(
            session=session,
            run_id=run_id,
            run_uuid=run_uuid,
            train_model=train_model,
        )
        # Determine quality state from result
        # More nuanced logic to avoid false alarms
        raw_inserted = result.get("news_raw_inserted", 0)
        proc_inserted = result.get("news_processed_inserted", 0)
        raw_error = result.get("news_raw_error")
        proc_error = result.get("news_processed_error")
        if raw_error or proc_error:
            # Actual errors during ingestion/processing
            quality_state = "degraded"
            result["message"] = f"Pipeline errors: {raw_error or ''} {proc_error or ''}".strip()
        elif raw_inserted == 0 and proc_inserted == 0:
            # No new data at all - could be dedup working or sources haven't updated
            quality_state = "stale"
            result["message"] = "No new articles - sources may not have updated"
        elif raw_inserted > 0 and proc_inserted == 0:
            # Got raw but nothing processed - potential dedup anomaly
            quality_state = "ok"  # This is actually fine - all duplicates
            result["message"] = f"All {raw_inserted} articles were duplicates"
        else:
            quality_state = "ok"
        # 3. Record success
        finished_at = datetime.now(timezone.utc)
        duration = (finished_at - started_at).total_seconds()
        finalize_run_metrics(
            session, run_id,
            status="success",
            quality_state=quality_state,
        )
        session.commit()
        logger.info(f"[run_id={run_id}] Pipeline completed in {duration:.1f}s")
        return {
            "run_id": run_id,
            "status": "success",
            "quality_state": quality_state,
            "started_at": started_at.isoformat(),
            "finished_at": finished_at.isoformat(),
            "duration_seconds": duration,
            "train_model": train_model,
            **result,
        }
    except Exception as e:
        logger.error(f"[run_id={run_id}] Pipeline failed: {e}", exc_info=True)
        try:
            finalize_run_metrics(
                session, run_id,
                status="failed",
                quality_state="failed",
                # Truncate to keep the DB column within bounds.
                error_message=str(e)[:1000],
            )
            session.commit()
        except Exception:
            # Best-effort: failure bookkeeping must never mask the real error.
            session.rollback()
        return {
            "run_id": run_id,
            "status": "failed",
            "error": str(e),
        }
    finally:
        # Always release lock and cleanup, even when stages raised.
        try:
            release_lock(session, PIPELINE_LOCK_KEY)
            clear_lock_visibility(session, PIPELINE_LOCK_KEY)
            session.commit()
        except Exception:
            session.rollback()
        finally:
            session.close()
async def _execute_pipeline_stages_v2(
    session: Session,
    run_id: str,
    run_uuid: uuid.UUID,
    train_model: bool,
) -> dict:
    """
    Execute pipeline stages with Faz 2 news pipeline integration.

    Each stage has its own commit boundary and metrics update; a stage
    failure is recorded in the result dict and rolled back so that later
    stages start from a clean transaction instead of an aborted one.

    Args:
        session: Session that already holds the pipeline advisory lock.
        run_id: String run identifier (used for metrics/log correlation).
        run_uuid: UUID form of the run id (passed to ingestion/processing).
        train_model: Whether to run the optional XGBoost training stage.

    Returns:
        dict of per-stage counters, flags, and "*_error" entries.
    """
    settings = get_settings()  # module-level import; no local re-import needed
    result = {}
    # -------------------------------------------------------------------------
    # Stage 1a: News ingestion → news_raw (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1a: News ingestion → news_raw")
    try:
        from pipelines.ingestion.news import ingest_news_to_raw
        raw_stats = ingest_news_to_raw(
            session=session,
            run_id=run_uuid,
        )
        session.commit()
        result["news_raw_inserted"] = raw_stats.get("inserted", 0)
        result["news_raw_duplicates"] = raw_stats.get("duplicates", 0)
        update_run_metrics(
            session, run_id,
            news_raw_inserted=raw_stats.get("inserted", 0),
            news_raw_duplicates=raw_stats.get("duplicates", 0),
        )
        session.commit()
        logger.info(f"[run_id={run_id}] news_raw: {raw_stats.get('inserted', 0)} inserted")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1a failed: {e}")
        result["news_raw_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 1b: Raw → Processed (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1b: news_raw → news_processed")
    try:
        from pipelines.processing.news import process_raw_to_processed
        proc_stats = process_raw_to_processed(
            session=session,
            run_id=run_uuid,
            batch_size=200,
        )
        session.commit()
        result["news_processed_inserted"] = proc_stats.get("inserted", 0)
        result["news_processed_duplicates"] = proc_stats.get("duplicates", 0)
        update_run_metrics(
            session, run_id,
            news_processed_inserted=proc_stats.get("inserted", 0),
            news_processed_duplicates=proc_stats.get("duplicates", 0),
        )
        session.commit()
        logger.info(f"[run_id={run_id}] news_processed: {proc_stats.get('inserted', 0)} inserted")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1b failed: {e}")
        result["news_processed_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 1c: Cut-off calculation (FAZ 2)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1c: Computing news cut-off")
    try:
        from pipelines.cutoff import compute_news_cutoff
        cutoff_dt = compute_news_cutoff(
            run_datetime=datetime.now(timezone.utc),
            market_tz=settings.market_timezone,
            market_close=settings.market_close_time,
            buffer_minutes=settings.cutoff_buffer_minutes,
        )
        result["news_cutoff_time"] = cutoff_dt.isoformat()
        update_run_metrics(session, run_id, news_cutoff_time=cutoff_dt)
        session.commit()
        logger.info(f"[run_id={run_id}] Cut-off: {cutoff_dt.isoformat()}")
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1c failed: {e}")
        result["cutoff_error"] = str(e)
        # BUGFIX: roll back like every other stage — a failed metrics
        # update would otherwise leave the transaction aborted and break
        # all subsequent stage commits.
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 1d: Price ingestion (existing)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 1d: Price ingestion")
    try:
        from app.data_manager import ingest_prices
        price_stats = ingest_prices(session)
        session.commit()
        result["symbols_fetched"] = len(price_stats)
        result["price_bars_updated"] = sum(
            s.get("imported", 0) for s in price_stats.values()
        )
        update_run_metrics(
            session, run_id,
            price_bars_updated=result["price_bars_updated"],
        )
        session.commit()
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 1d failed: {e}")
        result["price_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 2: Sentiment scoring (V2 - news_processed based)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 2: Sentiment scoring")
    try:
        from app.ai_engine import score_unscored_processed_articles
        scoring_stats = score_unscored_processed_articles(session)
        session.commit()
        result["articles_scored"] = int(scoring_stats.get("scored_count", 0))
        result["articles_scored_v2"] = int(scoring_stats.get("scored_count", 0))
        result["llm_parse_fail_count"] = int(scoring_stats.get("parse_fail_count", 0))
        result["escalation_count"] = int(scoring_stats.get("escalation_count", 0))
        result["fallback_count"] = int(scoring_stats.get("fallback_count", 0))
        update_run_metrics(
            session,
            run_id,
            articles_scored_v2=result["articles_scored_v2"],
            llm_parse_fail_count=result["llm_parse_fail_count"],
            escalation_count=result["escalation_count"],
            fallback_count=result["fallback_count"],
        )
        session.commit()
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 2 failed: {e}")
        result["scoring_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 3: Sentiment aggregation (existing)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 3: Sentiment aggregation")
    try:
        from app.ai_engine import aggregate_daily_sentiment_v2
        days_aggregated_v2 = aggregate_daily_sentiment_v2(session)
        session.commit()
        result["days_aggregated_v2"] = days_aggregated_v2
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 3 failed: {e}")
        result["aggregation_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 3.5: FinBERT embedding extraction (TFT-ASRO)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 3.5: FinBERT embedding extraction")
    try:
        from deep_learning.data.embeddings import backfill_embeddings
        emb_stats = backfill_embeddings(days=30, pca_dim=32, batch_size=64)
        session.commit()
        result["tft_embeddings_computed"] = emb_stats.get("embedded", 0)
        result["tft_embeddings_skipped"] = emb_stats.get("skipped", 0)
        result["tft_pca_fitted"] = emb_stats.get("pca_fitted", False)
        update_run_metrics(
            session, run_id,
            tft_embeddings_computed=emb_stats.get("embedded", 0),
        )
        session.commit()
        logger.info(
            f"[run_id={run_id}] FinBERT embeddings: "
            f"{emb_stats.get('embedded', 0)} computed, {emb_stats.get('skipped', 0)} skipped"
        )
    except ImportError:
        # Optional dependency: deep_learning may not be installed in all deployments.
        logger.info(f"[run_id={run_id}] Stage 3.5 skipped: deep_learning module not available")
    except Exception as e:
        logger.warning(f"[run_id={run_id}] Stage 3.5 failed (non-critical): {e}")
        result["tft_embedding_error"] = str(e)
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 4: Model training (optional)
    # -------------------------------------------------------------------------
    if train_model:
        logger.info(f"[run_id={run_id}] Stage 4: Model training")
        try:
            from app.ai_engine import train_xgboost_model, save_model_metadata_to_db
            train_result = train_xgboost_model(session)
            save_model_metadata_to_db(
                session,
                symbol="HG=F",
                importance=train_result.get("importance", []),
                features=train_result.get("features", []),
                metrics=train_result.get("metrics", {}),
            )
            session.commit()
            result["model_trained"] = True
            result["model_metrics"] = train_result.get("metrics", {})
            update_run_metrics(
                session, run_id,
                train_mae=train_result.get("metrics", {}).get("mae"),
                val_mae=train_result.get("metrics", {}).get("val_mae"),
            )
            session.commit()
        except Exception as e:
            logger.error(f"[run_id={run_id}] Stage 4 failed: {e}")
            result["training_error"] = str(e)
            result["model_trained"] = False
            session.rollback()
    else:
        result["model_trained"] = False
    # -------------------------------------------------------------------------
    # Stage 4.5: TFT-ASRO — inference only (training handled by weekly
    # tft-training.yml workflow; daily pipeline never retrains TFT)
    # -------------------------------------------------------------------------
    result["tft_trained"] = False
    # -------------------------------------------------------------------------
    # Stage 5: Generate snapshot
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 5: Generate snapshot")
    snapshot_report = None  # Will be used by Stage 6
    try:
        from app.inference import generate_analysis_report, save_analysis_snapshot
        report = generate_analysis_report(session, "HG=F")
        if report:
            # Add Faz 2 metadata
            report["quality_state"] = "ok"
            if result.get("news_processed_inserted", 0) == 0:
                report["quality_state"] = "degraded"
                report["message"] = "No fresh news data"
            save_analysis_snapshot(session, report, "HG=F")
            session.commit()
            result["snapshot_generated"] = True
            snapshot_report = report  # Save for Stage 6
            update_run_metrics(session, run_id, snapshot_generated=True)
            session.commit()
        else:
            result["snapshot_generated"] = False
    except Exception as e:
        logger.error(f"[run_id={run_id}] Stage 5 failed: {e}")
        result["snapshot_error"] = str(e)
        result["snapshot_generated"] = False
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 5.5: TFT-ASRO snapshot (parallel to XGBoost snapshot)
    # -------------------------------------------------------------------------
    logger.info(f"[run_id={run_id}] Stage 5.5: TFT-ASRO snapshot")
    try:
        from deep_learning.inference.predictor import generate_tft_analysis
        from deep_learning.config import get_tft_config
        from pathlib import Path
        tft_cfg = get_tft_config()
        ckpt = Path(tft_cfg.training.best_model_path)
        # If checkpoint is not cached locally, try to pull from HF Hub first
        if not ckpt.exists():
            try:
                from deep_learning.models.hub import download_tft_artifacts
                logger.info(f"[run_id={run_id}] TFT checkpoint not found locally – attempting HF Hub download")
                download_tft_artifacts(
                    local_dir=ckpt.parent,
                    repo_id=tft_cfg.training.hf_model_repo,
                )
            except Exception as hub_exc:
                logger.warning(f"[run_id={run_id}] HF Hub download failed: {hub_exc}")
        if ckpt.exists():
            tft_report = generate_tft_analysis(session, "HG=F")
            if "error" not in tft_report:
                result["tft_snapshot_generated"] = True
                update_run_metrics(session, run_id, tft_snapshot_generated=True)
                session.commit()
                logger.info(f"[run_id={run_id}] TFT-ASRO snapshot generated")
            else:
                result["tft_snapshot_generated"] = False
                logger.warning(f"[run_id={run_id}] TFT prediction error: {tft_report.get('error')}")
        else:
            result["tft_snapshot_generated"] = False
            logger.info(f"[run_id={run_id}] Stage 5.5 skipped: no TFT checkpoint found (train-tft workflow has not run yet)")
    except ImportError:
        result["tft_snapshot_generated"] = False
    except Exception as e:
        logger.warning(f"[run_id={run_id}] Stage 5.5 failed (non-critical): {e}")
        result["tft_snapshot_generated"] = False
        session.rollback()
    # -------------------------------------------------------------------------
    # Stage 6: Generate commentary (if any snapshot was generated)
    # -------------------------------------------------------------------------
    has_xgb_snapshot = result.get("snapshot_generated") and snapshot_report
    has_tft_snapshot = result.get("tft_snapshot_generated")
    if has_xgb_snapshot or has_tft_snapshot:
        logger.info(f"[run_id={run_id}] Stage 6: Generate commentary")
        try:
            from app.commentary import generate_and_save_commentary
            report = snapshot_report or {}
            await generate_and_save_commentary(
                session=session,
                symbol="HG=F",
                current_price=report.get("current_price", 0.0),
                predicted_price=report.get("predicted_price", 0.0),
                predicted_return=report.get("predicted_return", 0.0),
                sentiment_index=report.get("sentiment_index", 0.0),
                sentiment_label=report.get("sentiment_label", "Neutral"),
                top_influencers=report.get("top_influencers", []),
                news_count=report.get("data_quality", {}).get("news_count_7d", 0),
            )
            session.commit()
            result["commentary_generated"] = True
            update_run_metrics(session, run_id, commentary_generated=True)
            session.commit()
        except Exception as e:
            logger.warning(f"[run_id={run_id}] Stage 6 failed: {e}")
            result["commentary_generated"] = False
            # BUGFIX: clear the aborted transaction so the caller's
            # finalize/commit does not fail on a dirty session.
            session.rollback()
    else:
        logger.warning(f"[run_id={run_id}] Stage 6 skipped: no snapshot generated")
        result["commentary_generated"] = False
    return result
# =============================================================================
# arq worker lifecycle
# =============================================================================
async def startup(ctx: dict) -> None:
    """arq on_startup hook: log the start and ensure the DB schema exists."""
    logger.info("Worker starting up...")
    init_db()
async def shutdown(ctx: dict) -> None:
    """arq on_shutdown hook: log that the worker is stopping."""
    logger.info("Worker shutting down...")