# Page-scrape header (not part of the program): Spaces: Running; File size: 5,565 Bytes; d317049
"""Backfill market-date sentiment alignment for weekly TFT training."""
from __future__ import annotations
import argparse
import json
import pathlib
import sys
from datetime import date, datetime, timezone
from typing import Any
# BACKEND_ROOT is the parent of the directory that contains this script.
# It is pushed to the front of ``sys.path`` (once) before the project imports
# that follow -- presumably so ``app`` and ``pipelines`` resolve when the
# script is run directly; confirm against the repository layout.
BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1]
if str(BACKEND_ROOT) not in sys.path:
    sys.path.insert(0, str(BACKEND_ROOT))
from sqlalchemy import func
from app.ai_engine import aggregate_daily_sentiment_v2
from app.db import SessionLocal, init_db
from app.models import DailySentimentV2, NewsProcessed, NewsRaw, NewsSentimentV2
from pipelines.market_calendar import assign_market_date, is_after_close_news
def _parse_date(value: str) -> datetime:
    """Convert a CLI date argument into a timezone-aware ``datetime``.

    Args:
        value: Either the keyword ``today`` (any letter case) or an ISO-8601
            date / date-time string. Surrounding whitespace is ignored and a
            trailing ``Z`` is read as UTC.

    Returns:
        The current UTC time for ``today``; otherwise the parsed value. A
        value without an offset is tagged as UTC, an explicit offset is kept
        as given, and a date-only string resolves to midnight of that day.

    Raises:
        ValueError: If ``value`` is neither ``today`` nor a valid ISO string.
    """
    text = value.strip()
    if text.lower() == "today":
        return datetime.now(timezone.utc)
    # ``datetime.fromisoformat`` only accepts a trailing "Z" from Python 3.11
    # onwards; rewrite it as an explicit offset so every version agrees.
    if text[-1:] in ("Z", "z"):
        text = f"{text[:-1]}+00:00"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
def _audit_counts(session, start: datetime, end: datetime) -> dict[str, Any]:
    """Summarise what a backfill over ``[start, end]`` would touch.

    Loads every ``(NewsSentimentV2, NewsRaw)`` pair whose
    ``NewsRaw.published_at`` lies in the inclusive window and derives counts
    plus a small sample of timestamp -> market-date mappings. Nothing is
    written to the session.

    Args:
        session: Session used to run the queries.
        start: Inclusive lower bound on ``NewsRaw.published_at``.
        end: Inclusive upper bound on ``NewsRaw.published_at``.

    Returns:
        A dict with these keys:

        * ``rows_to_update`` -- number of pairs in the window.
        * ``null_market_date_before`` -- pairs whose sentiment row currently
          has ``market_date`` set to ``None``.
        * ``null_market_date_after_expected`` -- ``0`` when the window has
          rows, otherwise the (empty-window) null count.
        * ``after_close_shift_count`` -- pairs for which
          ``is_after_close_news`` is truthy.
        * ``weekend_shift_count`` -- pairs whose ``published_at.weekday()``
          is 5 or 6.
        * ``daily_market_date_null_before`` -- ``DailySentimentV2`` rows with
          a NULL ``market_date`` (whole table, not limited to the window).
        * ``sample_mappings`` -- at most 10 distinct example rows, shifted
          (after-close / weekend) rows first.
    """
    pairs = (
        session.query(NewsSentimentV2, NewsRaw)
        .join(NewsProcessed, NewsSentimentV2.news_processed_id == NewsProcessed.id)
        .join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
        .filter(NewsRaw.published_at >= start, NewsRaw.published_at <= end)
        .all()
    )

    def _sample(raw, market_date) -> dict[str, str]:
        # One entry of ``sample_mappings``.
        return {
            "published_at_utc": raw.published_at.isoformat(),
            "market_date": market_date.isoformat(),
        }

    null_before = 0
    after_close = 0
    weekend = 0
    mappings: list[dict[str, str]] = []
    sampled: set[int] = set()  # indices of pairs already present in ``mappings``
    for index, (sentiment, raw) in enumerate(pairs):
        if sentiment.market_date is None:
            null_before += 1
        # Evaluated for every row (not only sampled ones) so the audit
        # exercises the same mapping call the real backfill will make.
        market_date = assign_market_date(raw.published_at)
        # NOTE(review): weekday is taken from the stored timestamp as-is; the
        # "published_at_utc" sample key suggests UTC -- confirm.
        is_weekend = raw.published_at.weekday() >= 5
        is_after_close = is_after_close_news(raw.published_at)
        after_close += int(is_after_close)
        weekend += int(is_weekend)
        if len(mappings) < 10 and (is_after_close or is_weekend):
            mappings.append(_sample(raw, market_date))
            sampled.add(index)

    if len(mappings) < 5:
        # Too few shifted examples: pad the sample up to 10 entries with the
        # earliest rows, skipping any row that was already sampled above so
        # the same row never appears twice.
        for index, (_sentiment, raw) in enumerate(pairs):
            if len(mappings) >= 10:
                break
            if index in sampled:
                continue
            mappings.append(_sample(raw, assign_market_date(raw.published_at)))

    return {
        "rows_to_update": len(pairs),
        "null_market_date_before": null_before,
        "null_market_date_after_expected": 0 if pairs else null_before,
        "after_close_shift_count": after_close,
        "weekend_shift_count": weekend,
        "daily_market_date_null_before": session.query(DailySentimentV2)
        .filter(DailySentimentV2.market_date.is_(None))
        .count(),
        "sample_mappings": mappings,
    }
def _backfill_news_sentiments(session, start: datetime, end: datetime) -> int:
    """Stamp market-date alignment fields on each sentiment row in the window.

    For every ``(NewsSentimentV2, NewsRaw)`` pair whose
    ``NewsRaw.published_at`` falls inside ``[start, end]`` (both inclusive),
    the sentiment row receives:

    * ``market_date`` -- ``assign_market_date(published_at)``
    * ``available_at`` -- ``fetched_at`` when truthy, else ``published_at``
    * ``cutoff_version`` -- the literal ``"market_close_v1"``

    No commit is issued here; that is left to the caller.

    Args:
        session: Session used to load the rows.
        start: Inclusive lower bound on ``NewsRaw.published_at``.
        end: Inclusive upper bound on ``NewsRaw.published_at``.

    Returns:
        The number of sentiment rows that were assigned.
    """
    in_window = (NewsRaw.published_at >= start, NewsRaw.published_at <= end)
    query = session.query(NewsSentimentV2, NewsRaw)
    query = query.join(NewsProcessed, NewsSentimentV2.news_processed_id == NewsProcessed.id)
    query = query.join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
    query = query.filter(*in_window)

    touched = 0
    for sentiment_row, raw_row in query.all():
        sentiment_row.market_date = assign_market_date(raw_row.published_at)
        sentiment_row.available_at = raw_row.fetched_at or raw_row.published_at
        sentiment_row.cutoff_version = "market_close_v1"
        touched += 1
    return touched
def _duplicate_market_dates(session) -> list[dict[str, Any]]:
    """List every non-NULL ``market_date`` held by more than one daily row.

    Args:
        session: Session used to query ``DailySentimentV2``.

    Returns:
        One ``{"market_date": <str>, "count": <int>}`` dict per market date
        that occurs on two or more ``DailySentimentV2`` rows; an empty list
        when there are none.
    """
    row_count = func.count(DailySentimentV2.id)
    grouped = (
        session.query(DailySentimentV2.market_date, row_count)
        .filter(DailySentimentV2.market_date.is_not(None))
        .group_by(DailySentimentV2.market_date)
        .having(row_count > 1)
    )

    duplicates: list[dict[str, Any]] = []
    for market_date, occurrences in grouped.all():
        duplicates.append({"market_date": str(market_date), "count": int(occurrences)})
    return duplicates
def run_backfill(start: datetime, end: datetime, *, dry_run: bool) -> dict[str, Any]:
    """Audit, and unless ``dry_run`` is set apply, the market-date backfill.

    The audit from ``_audit_counts`` is always produced first. In a dry run
    it is returned with the current duplicate report and nothing is written.
    Otherwise the sentiment rows in the window are updated and committed, the
    daily aggregates are rebuilt via ``aggregate_daily_sentiment_v2`` and
    committed, and post-run NULL / duplicate checks are added to the report.

    Args:
        start: Inclusive lower bound on ``NewsRaw.published_at``.
        end: Inclusive upper bound on ``NewsRaw.published_at``.
        dry_run: When true, only report; do not modify the database.

    Returns:
        The audit dict, extended with ``dry_run``, ``start``, ``end`` and the
        run-specific result keys.
    """
    init_db()
    with SessionLocal() as session:
        report = _audit_counts(session, start, end)
        report.update(
            dry_run=dry_run,
            start=start.date().isoformat(),
            end=end.date().isoformat(),
        )

        if dry_run:
            report.update(
                daily_aggregate_rows_rebuilt=0,
                duplicates_after_expected=_duplicate_market_dates(session),
            )
            return report

        # Two separate commits: the per-article updates are persisted before
        # the daily aggregates are rebuilt from them.
        updated_rows = _backfill_news_sentiments(session, start, end)
        session.commit()
        rebuilt_rows = aggregate_daily_sentiment_v2(session)
        session.commit()

        news_null_query = session.query(NewsSentimentV2).filter(
            NewsSentimentV2.market_date.is_(None)
        )
        news_nulls_left = news_null_query.count()
        daily_null_query = session.query(DailySentimentV2).filter(
            DailySentimentV2.market_date.is_(None)
        )
        daily_nulls_left = daily_null_query.count()

        report.update(
            rows_updated=updated_rows,
            daily_aggregate_rows_rebuilt=rebuilt_rows,
            news_market_date_null_after=news_nulls_left,
            daily_market_date_null_after=daily_nulls_left,
            duplicates_after=_duplicate_market_dates(session),
        )
        return report
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: parse options, run the backfill, print JSON.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list lets other code or tests drive the command directly.

    Returns:
        ``0`` once the audit report has been printed.

    NOTE(review): a date-only ``--end`` resolves to midnight at the start of
    that day, while the queries filter on ``published_at <= end`` -- confirm
    whether the final day is meant to be included.
    """
    parser = argparse.ArgumentParser(description="Backfill market_date for TFT weekly sentiment features")
    parser.add_argument(
        "--start",
        default="2020-01-01",
        help="inclusive lower bound on published_at: ISO date/datetime or 'today'",
    )
    parser.add_argument(
        "--end",
        default="today",
        help=(
            "inclusive upper bound on published_at: ISO date/datetime or "
            "'today' (a date-only value means midnight at the start of that day)"
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="print the audit report without modifying the database",
    )
    args = parser.parse_args(argv)

    audit = run_backfill(
        _parse_date(args.start),
        _parse_date(args.end),
        dry_run=args.dry_run,  # ``store_true`` already yields a bool
    )
    # ``default=str`` stringifies any value json cannot serialise natively.
    print(json.dumps(audit, indent=2, default=str))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
# End of file (stray "|" left over from the page scrape removed).