# JNU-TSB / runtime.py
# Uploaded by HONGRIZON ("Upload 18 files", revision cf02581, verified).
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
import pandas as pd
import torch
try:
from .configuration_jnu_tsb import JNUTSBConfig
from .event_extractor import COVARIATE_COLUMNS, EventExtractor
except ImportError: # pragma: no cover - local execution fallback
from configuration_jnu_tsb import JNUTSBConfig
from event_extractor import COVARIATE_COLUMNS, EventExtractor
class JNUTSBRuntime:
    """Runtime used by the model wrapper, pipeline, Endpoint handler, Gradio, and R examples.

    Routes inputs into three paths:

    1. stock only   -> Chronos-2 forecast
    2. news only    -> event extraction and daily covariates
    3. stock + news -> news covariates + stock context -> Chronos-2 forecast

    Heavy resources (the Chronos-2 pipeline, the extractor LLM, the event
    extractor) are constructed lazily on first use so that importing and
    instantiating this class stays cheap.
    """

    def __init__(
        self,
        config: Union[JNUTSBConfig, Dict[str, Any]],
        chronos_device_map: Optional[str] = None,
        llm_device_map: Optional[str] = None,
    ) -> None:
        """Create a runtime from a config object or a plain config dict.

        Args:
            config: ``JNUTSBConfig`` instance, or a dict of its constructor
                keyword arguments (converted in place).
            chronos_device_map: ``device_map`` passed to the Chronos-2
                pipeline; falls back to the ``JNU_TSB_CHRONOS_DEVICE_MAP``
                environment variable, then ``"cpu"``.
            llm_device_map: ``device_map`` passed to the extractor LLM; falls
                back to ``JNU_TSB_LLM_DEVICE_MAP``, then ``"cpu"``.
        """
        if isinstance(config, dict):
            config = JNUTSBConfig(**config)
        self.config = config
        self.chronos_device_map = chronos_device_map or os.getenv("JNU_TSB_CHRONOS_DEVICE_MAP", "cpu")
        self.llm_device_map = llm_device_map or os.getenv("JNU_TSB_LLM_DEVICE_MAP", "cpu")
        # Lazily-initialized caches; see the `chronos` / `extractor`
        # properties and `_generate_with_polyglot`.
        self._chronos = None
        self._llm_pipe = None
        self._extractor = None

    @classmethod
    def from_config(cls, config: Union[JNUTSBConfig, Dict[str, Any]], **kwargs: Any) -> "JNUTSBRuntime":
        """Alternate constructor; thin alias for ``cls(config=..., **kwargs)``."""
        return cls(config=config, **kwargs)

    @classmethod
    def from_config_dir(cls, path: Union[str, os.PathLike[str]], **kwargs: Any) -> "JNUTSBRuntime":
        """Build a runtime from a directory containing a ``config.json`` file.

        ``kwargs`` are forwarded to ``__init__`` (e.g. the device maps).
        """
        path = Path(path)
        with open(path / "config.json", "r", encoding="utf-8") as f:
            payload = json.load(f)
        return cls(config=payload, **kwargs)

    @property
    def chronos(self):
        """Chronos-2 pipeline, imported and loaded on first access.

        Raises:
            ImportError: if the optional ``chronos-forecasting`` package is
                not installed.
        """
        if self._chronos is None:
            try:
                from chronos import Chronos2Pipeline
            except Exception as exc:  # pragma: no cover
                raise ImportError(
                    "chronos-forecasting is required for Chronos-2 inference. "
                    "Install it with: pip install chronos-forecasting"
                ) from exc
            self._chronos = Chronos2Pipeline.from_pretrained(
                self.config.chronos_model_id,
                device_map=self.chronos_device_map,
            )
        return self._chronos

    @property
    def extractor(self) -> EventExtractor:
        """Event extractor, built on first access.

        The LLM-backed generate function is only wired in when
        ``config.use_llm_extractor`` is true; otherwise the extractor runs in
        its non-LLM mode. ``predict`` may reset ``self._extractor`` to force a
        rebuild when the caller overrides ``use_llm_extractor``.
        """
        if self._extractor is None:
            self._extractor = EventExtractor(
                generate_fn=self._generate_with_polyglot if self.config.use_llm_extractor else None,
                categories=self.config.event_categories,
                use_llm=self.config.use_llm_extractor,
            )
        return self._extractor

    def _generate_with_polyglot(self, prompt: str) -> str:
        """Generate text for the event extractor with the configured LLM.

        Loads ``config.llm_model_id`` as a transformers ``text-generation``
        pipeline on first call and caches it. Decoding is greedy
        (``do_sample=False``) and capped at 96 new tokens; only the completion
        (not the prompt) is returned via ``return_full_text=False``.
        """
        if self._llm_pipe is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline
            tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_id)
            model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_id,
                # fp16 only when a GPU is available; fp32 keeps CPU inference safe.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map=self.llm_device_map,
            )
            self._llm_pipe = hf_pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        output = self._llm_pipe(
            prompt,
            max_new_tokens=96,
            do_sample=False,
            return_full_text=False,
        )
        # transformers returns a list of dicts with "generated_text"; fall
        # back to str() for any other shape.
        if isinstance(output, list) and output:
            return output[0].get("generated_text", "")
        return str(output)

    def predict(
        self,
        inputs: Optional[Dict[str, Any]] = None,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[Sequence[float]] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Route a request to one of the three inference paths.

        The payload is ``inputs`` merged with ``kwargs`` (kwargs win) and may
        contain ``stock``, ``news``, ``future_news``, ``future_covariates``.
        At least one of ``stock``/``news`` is required.

        Returns:
            A JSON-serializable dict whose ``"route"`` key is one of
            ``"hybrid"``, ``"chronos_only"``, or ``"text_only"``.

        Raises:
            ValueError: if neither ``stock`` nor ``news`` is provided.
        """
        payload: Dict[str, Any] = dict(inputs or {})
        payload.update(kwargs)
        if use_llm_extractor is not None and bool(use_llm_extractor) != self.config.use_llm_extractor:
            # Rebuild extractor with the requested setting for this runtime instance.
            self.config.use_llm_extractor = bool(use_llm_extractor)
            self._extractor = None
        # NOTE: falsy overrides (0, empty sequence) fall back to the config defaults.
        prediction_length = int(prediction_length or self.config.prediction_length)
        quantile_levels = list(quantile_levels or self.config.quantile_levels)
        allow_naive_fallback = self.config.allow_naive_fallback if allow_naive_fallback is None else bool(allow_naive_fallback)
        news = payload.get("news")
        stock = payload.get("stock")
        future_news = payload.get("future_news")
        future_covariates = payload.get("future_covariates")
        has_news = bool(news)
        stock_df = self._prepare_stock_df(stock)
        has_stock = stock_df is not None and not stock_df.empty
        if has_news and has_stock:
            # Route 3 (hybrid): news-derived daily covariates merged into the
            # stock context, with optional known-future covariates.
            context_df = self._merge_news_covariates(stock_df, news)
            future_df = self._prepare_future_covariates(
                stock_df=context_df,
                future_news=future_news,
                future_covariates=future_covariates,
                prediction_length=prediction_length,
            )
            return self._forecast(
                context_df=context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="hybrid",
                future_df=future_df,
                allow_naive_fallback=allow_naive_fallback,
            )
        if has_stock:
            # Route 1: plain Chronos-2 forecast on the stock series.
            return self._forecast(
                context_df=stock_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="chronos_only",
                future_df=None,
                allow_naive_fallback=allow_naive_fallback,
            )
        if has_news:
            # Route 2: no forecast — return per-item events plus daily
            # aggregated covariates. Assumes `news` is an iterable of dicts
            # with a title/headline/text field — TODO confirm against callers.
            events = [self.extractor.extract(item.get("title") or item.get("headline") or item.get("text") or "") for item in news]
            daily_covariates = self.extractor.aggregate_to_daily(news)
            return {
                "route": "text_only",
                "repo_id": self.config.repo_id,
                "events": events,
                "daily_covariates": self._df_to_records(daily_covariates),
            }
        raise ValueError("JNU-TSB expects at least one of: stock, news.")

    def _forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: Sequence[float],
        route: str,
        future_df: Optional[pd.DataFrame] = None,
        allow_naive_fallback: bool = True,
    ) -> Dict[str, Any]:
        """Run Chronos-2 on ``context_df``; optionally fall back to naive.

        Any exception from the Chronos path (import, load, or predict) is
        converted into a last-value forecast when ``allow_naive_fallback`` is
        true; the response then carries ``used_naive_fallback=True`` and a
        ``warning`` describing the original failure.
        """
        try:
            kwargs = dict(
                prediction_length=prediction_length,
                quantile_levels=list(quantile_levels),
                id_column=self.config.id_column,
                timestamp_column=self.config.timestamp_column,
                target=self.config.target_column,
            )
            if future_df is not None and not future_df.empty:
                pred = self.chronos.predict_df(context_df, future_df=future_df, **kwargs)
            else:
                pred = self.chronos.predict_df(context_df, **kwargs)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": self.config.chronos_model_id,
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": False,
            }
        except Exception as exc:
            # Deliberately broad: any Chronos failure degrades to the naive
            # path unless the caller opted out.
            if not allow_naive_fallback:
                raise
            pred = self._naive_forecast(context_df, prediction_length, quantile_levels)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": "naive_last_value_fallback",
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": True,
                "warning": f"Chronos-2 inference failed or was unavailable: {type(exc).__name__}: {exc}",
            }

    def _prepare_stock_df(self, stock: Any) -> Optional[pd.DataFrame]:
        """Normalize user-supplied stock data into a canonical DataFrame.

        Accepts a DataFrame, a list of row dicts, or a dict of columns.
        Common alternative column names are renamed to the configured
        timestamp/target columns, a default item id is injected when missing,
        timestamps are parsed, and rows are sorted by (id, timestamp).

        Returns ``None`` for ``None`` input and an empty DataFrame as-is.

        Raises:
            TypeError: on an unsupported input type.
            ValueError: when the timestamp or target column cannot be found.
        """
        if stock is None:
            return None
        if isinstance(stock, pd.DataFrame):
            df = stock.copy()
        elif isinstance(stock, list):
            df = pd.DataFrame(stock)
        elif isinstance(stock, dict):
            df = pd.DataFrame(stock)
        else:
            raise TypeError("stock must be a pandas DataFrame, list of dicts, or dict of columns.")
        if df.empty:
            return df
        timestamp_col = self.config.timestamp_column
        if timestamp_col not in df.columns:
            # Accept common timestamp spellings; first match wins.
            for cand in ("date", "Date", "datetime", "time"):
                if cand in df.columns:
                    df = df.rename(columns={cand: timestamp_col})
                    break
        target_col = self.config.target_column
        if target_col not in df.columns:
            # Accept common target spellings; first match wins.
            for cand in ("close", "Close", "price", "value", "y"):
                if cand in df.columns:
                    df = df.rename(columns={cand: target_col})
                    break
        if timestamp_col not in df.columns or target_col not in df.columns:
            raise ValueError(f"stock must contain '{timestamp_col}' and '{target_col}' columns.")
        if self.config.id_column not in df.columns:
            df[self.config.id_column] = self.config.default_item_id
        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
        df = df.sort_values([self.config.id_column, timestamp_col]).reset_index(drop=True)
        return df

    def _prepare_future_df(self, data: Any) -> Optional[pd.DataFrame]:
        """Normalize user-supplied future covariates into a DataFrame.

        Same coercion rules as ``_prepare_stock_df`` except that only the
        timestamp column is required (no target).

        Raises:
            TypeError: on an unsupported input type.
            ValueError: when the timestamp column cannot be found.
        """
        if data is None:
            return None
        if isinstance(data, pd.DataFrame):
            df = data.copy()
        elif isinstance(data, list):
            df = pd.DataFrame(data)
        elif isinstance(data, dict):
            df = pd.DataFrame(data)
        else:
            raise TypeError("future_covariates must be a pandas DataFrame, list of dicts, or dict of columns.")
        if df.empty:
            return df
        timestamp_col = self.config.timestamp_column
        if timestamp_col not in df.columns:
            for cand in ("date", "Date", "datetime", "time"):
                if cand in df.columns:
                    df = df.rename(columns={cand: timestamp_col})
                    break
        if timestamp_col not in df.columns:
            raise ValueError(f"future_covariates must contain a '{timestamp_col}' column.")
        if self.config.id_column not in df.columns:
            df[self.config.id_column] = self.config.default_item_id
        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
        df = df.sort_values([self.config.id_column, timestamp_col]).reset_index(drop=True)
        return df

    def _merge_news_covariates(self, stock_df: pd.DataFrame, news: Iterable[Dict[str, Any]]) -> pd.DataFrame:
        """Left-join daily news covariates onto the stock context.

        Rows are matched on the calendar day (timestamps floored to midnight);
        days with no news get 0.0 for every covariate column. Relies on
        ``aggregate_to_daily`` returning a ``timestamp`` column (see rename
        below) — confirm against the extractor implementation.
        """
        cov = self.extractor.aggregate_to_daily(news)
        context = stock_df.copy()
        day_col = "__day__"
        context[day_col] = pd.to_datetime(context[self.config.timestamp_column]).dt.floor("D")
        cov = cov.rename(columns={"timestamp": day_col})
        merged = context.merge(cov, on=day_col, how="left").drop(columns=[day_col])
        for col in COVARIATE_COLUMNS:
            if col in merged.columns:
                merged[col] = merged[col].fillna(0).astype(float)
        return merged

    def _prepare_future_covariates(
        self,
        stock_df: pd.DataFrame,
        future_news: Optional[Iterable[Dict[str, Any]]],
        future_covariates: Any,
        prediction_length: int,
    ) -> Optional[pd.DataFrame]:
        """Build the future-covariate frame for a hybrid forecast.

        Precedence: explicit ``future_covariates`` win (target column dropped
        if present); otherwise a future index is synthesized from the stock
        history's inferred frequency and joined with covariates aggregated
        from ``future_news``. Returns ``None`` when neither source is usable.

        NOTE(review): the synthesized frame uses only the first item id from
        ``stock_df`` — multi-series future covariates from news are not
        generated here.
        """
        if future_covariates is not None:
            fut = self._prepare_future_df(future_covariates)
            if fut is not None and not fut.empty:
                return fut.drop(columns=[self.config.target_column], errors="ignore")
        if not future_news:
            return None
        first_id = stock_df[self.config.id_column].iloc[0]
        last_ts = pd.to_datetime(stock_df[self.config.timestamp_column]).max()
        # Fall back to daily frequency when pandas cannot infer one.
        freq = pd.infer_freq(pd.to_datetime(stock_df[self.config.timestamp_column]).drop_duplicates().sort_values()) or "D"
        # First generated date equals last_ts, so slice it off.
        future_dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
        base = pd.DataFrame({
            self.config.timestamp_column: future_dates,
            self.config.id_column: first_id,
        })
        cov = self.extractor.aggregate_to_daily(future_news)
        if cov.empty:
            return base
        cov_day = cov.rename(columns={"timestamp": "__day__"})
        base["__day__"] = pd.to_datetime(base[self.config.timestamp_column]).dt.floor("D")
        merged = base.merge(cov_day, on="__day__", how="left").drop(columns=["__day__"])
        for col in COVARIATE_COLUMNS:
            if col in merged.columns:
                merged[col] = merged[col].fillna(0).astype(float)
        return merged

    def _naive_forecast(self, context_df: pd.DataFrame, prediction_length: int, quantile_levels: Sequence[float]) -> pd.DataFrame:
        """Last-value fallback forecast, one flat path per item id.

        Every future step repeats the series' last observed target value for
        the mean, the point prediction, and every quantile. Each quantile is
        written under both ``str(q)`` and ``f"q{q}"`` — presumably to match
        the column spellings different consumers expect; verify before
        removing either.
        """
        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        id_col = self.config.id_column
        rows: List[Dict[str, Any]] = []
        for item_id, group in context_df.groupby(id_col):
            group = group.sort_values(timestamp_col)
            last_ts = pd.to_datetime(group[timestamp_col].iloc[-1])
            last_value = float(group[target_col].iloc[-1])
            freq = pd.infer_freq(pd.to_datetime(group[timestamp_col]).drop_duplicates().sort_values()) or "D"
            # Skip index 0: date_range starts at the last observed timestamp.
            dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
            for ts in dates:
                row: Dict[str, Any] = {id_col: item_id, timestamp_col: ts}
                for q in quantile_levels:
                    row[str(q)] = last_value
                    row[f"q{q}"] = last_value
                row["mean"] = last_value
                row["prediction"] = last_value
                rows.append(row)
        return pd.DataFrame(rows)

    def _df_to_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert a DataFrame to a list of row dicts for JSON responses.

        Datetime columns are stringified first so timestamps survive JSON
        serialization; other dtypes are passed through unchanged.
        """
        out = df.copy()
        for col in out.columns:
            if pd.api.types.is_datetime64_any_dtype(out[col]):
                out[col] = out[col].astype(str)
        return out.to_dict(orient="records")