ifieryarrows committed on
Commit
dff0b7c
·
verified ·
1 Parent(s): 3ef755a

Sync from GitHub (tests passed)

Browse files
app/ai_engine.py CHANGED
@@ -1409,7 +1409,11 @@ def score_unscored_processed_articles(
1409
 
1410
  llm_results_by_id: dict[int, dict] = {}
1411
  llm_candidates: list[dict] = []
1412
- global_rate_limited = getattr(score_unscored_processed_articles, "_rate_limited", False)
 
 
 
 
1413
 
1414
  if settings.openrouter_api_key and llm_budget_remaining > 0 and not global_rate_limited:
1415
  llm_take = min(len(chunk_items), llm_budget_remaining)
@@ -1430,10 +1434,14 @@ def score_unscored_processed_articles(
1430
  fast_model = str(llm_bundle.get("model_fast", fast_model))
1431
  reliable_model = str(llm_bundle.get("model_reliable", reliable_model))
1432
 
1433
- # If LLM returned 100% fail and flagged rate limit:
 
1434
  if llm_bundle.get("rate_limited", False):
1435
- score_unscored_processed_articles._rate_limited = True
1436
- logger.warning("V2 batch hit rate limit - disabling LLM for remaining chunks in this run.")
 
 
 
1437
 
1438
  except Exception as exc:
1439
  logger.warning("V2 LLM scoring failed for chunk starting at %s: %s", chunk_idx, exc)
 
1409
 
1410
  llm_results_by_id: dict[int, dict] = {}
1411
  llm_candidates: list[dict] = []
1412
+
1413
+ # Rate-limit flag is keyed to today's UTC date so it resets automatically at midnight.
1414
+ today_utc = datetime.now(timezone.utc).date().isoformat()
1415
+ rate_limited_date = getattr(score_unscored_processed_articles, "_rate_limited_date", None)
1416
+ global_rate_limited = rate_limited_date == today_utc
1417
 
1418
  if settings.openrouter_api_key and llm_budget_remaining > 0 and not global_rate_limited:
1419
  llm_take = min(len(chunk_items), llm_budget_remaining)
 
1434
  fast_model = str(llm_bundle.get("model_fast", fast_model))
1435
  reliable_model = str(llm_bundle.get("model_reliable", reliable_model))
1436
 
1437
+ # If LLM returned 100% fail and flagged rate limit, mark for today's UTC date.
1438
+ # Flag resets automatically the next UTC day when the daily limit refreshes.
1439
  if llm_bundle.get("rate_limited", False):
1440
+ score_unscored_processed_articles._rate_limited_date = datetime.now(timezone.utc).date().isoformat()
1441
+ logger.warning(
1442
+ "V2 batch hit OpenRouter daily rate limit - LLM scoring disabled for the rest of UTC day %s.",
1443
+ score_unscored_processed_articles._rate_limited_date,
1444
+ )
1445
 
1446
  except Exception as exc:
1447
  logger.warning("V2 LLM scoring failed for chunk starting at %s: %s", chunk_idx, exc)
app/features.py CHANGED
@@ -18,6 +18,8 @@ import pandas as pd
18
  from sqlalchemy import func
19
  from sqlalchemy.orm import Session
20
 
 
 
21
  from app.db import SessionLocal
22
  from app.models import PriceBar, DailySentiment, DailySentimentV2
23
  from app.settings import get_settings
 
18
  from sqlalchemy import func
19
  from sqlalchemy.orm import Session
20
 
21
+ pd.set_option("future.no_silent_downcasting", True)
22
+
23
  from app.db import SessionLocal
24
  from app.models import PriceBar, DailySentiment, DailySentimentV2
25
  from app.settings import get_settings
deep_learning/inference/predictor.py CHANGED
@@ -19,6 +19,8 @@ from functools import lru_cache
19
  from pathlib import Path
20
  from typing import Any, Dict, Optional
21
 
 
 
22
  import numpy as np
23
  import pandas as pd
24
 
@@ -142,7 +144,8 @@ class TFTPredictor:
142
  logger.error("Failed to create inference dataset: %s", exc)
143
  return {"error": str(exc)}
144
 
145
- dl = ds.to_dataloader(train=False, batch_size=1, num_workers=0)
 
146
 
147
  try:
148
  import torch
 
19
  from pathlib import Path
20
  from typing import Any, Dict, Optional
21
 
22
+ import os
23
+
24
  import numpy as np
25
  import pandas as pd
26
 
 
144
  logger.error("Failed to create inference dataset: %s", exc)
145
  return {"error": str(exc)}
146
 
147
+ _nw = 0 if os.name == "nt" else 2
148
+ dl = ds.to_dataloader(train=False, batch_size=1, num_workers=_nw)
149
 
150
  try:
151
  import torch
deep_learning/models/tft_copper.py CHANGED
@@ -179,15 +179,19 @@ def load_tft_model(
179
  # Interpretation helpers
180
  # ---------------------------------------------------------------------------
181
 
182
- def get_variable_importance(model) -> Dict[str, float]:
183
  """
184
  Extract learned variable importance from the TFT's Variable Selection Networks.
185
 
186
  Returns a dict mapping feature name -> normalised importance score.
 
 
187
  """
 
 
188
  try:
189
  interpretation = model.interpret_output(
190
- model.predict(model.val_dataloader(), return_x=True),
191
  reduction="sum",
192
  )
193
  importance = interpretation.get("encoder_variables", {})
 
179
  # Interpretation helpers
180
  # ---------------------------------------------------------------------------
181
 
182
+ def get_variable_importance(model, val_dataloader=None) -> Dict[str, float]:
183
  """
184
  Extract learned variable importance from the TFT's Variable Selection Networks.
185
 
186
  Returns a dict mapping feature name -> normalised importance score.
187
+ val_dataloader must be passed explicitly (model.val_dataloader() only works
188
+ inside a Lightning Trainer context and raises an error otherwise).
189
  """
190
+ if val_dataloader is None:
191
+ return {}
192
  try:
193
  interpretation = model.interpret_output(
194
+ model.predict(val_dataloader, return_x=True),
195
  reduction="sum",
196
  )
197
  importance = interpretation.get("encoder_variables", {})
deep_learning/training/trainer.py CHANGED
@@ -230,7 +230,7 @@ def train_tft_model(
230
  logger.info("Test metrics: %s", {k: f"{v:.4f}" for k, v in test_metrics.items()})
231
 
232
  # ---- 8. Variable importance ----
233
- var_importance = get_variable_importance(model)
234
 
235
  # ---- 9. Persist metadata ----
236
  result = {
 
230
  logger.info("Test metrics: %s", {k: f"{v:.4f}" for k, v in test_metrics.items()})
231
 
232
  # ---- 8. Variable importance ----
233
+ var_importance = get_variable_importance(model, val_dataloader=val_dl)
234
 
235
  # ---- 9. Persist metadata ----
236
  result = {