from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

import pandas as pd
import torch

try:
    from .configuration_jnu_tsb import JNUTSBConfig
    from .event_extractor import COVARIATE_COLUMNS, EventExtractor
except ImportError:  # pragma: no cover - local execution fallback
    from configuration_jnu_tsb import JNUTSBConfig
    from event_extractor import COVARIATE_COLUMNS, EventExtractor


class JNUTSBRuntime:
    """Runtime used by the model wrapper, pipeline, Endpoint handler, Gradio, and R examples.

    Routes inputs into three paths:
    1. stock only   -> Chronos-2 forecast
    2. news only    -> event extraction and daily covariates
    3. stock + news -> news covariates + stock context -> Chronos-2 forecast

    The Chronos-2 pipeline, the text-generation LLM and the event extractor
    are all created lazily on first use.
    """

    # Column names accepted when the configured timestamp / target column is
    # missing from user input; the first alias present in the frame wins.
    _TIMESTAMP_ALIASES = ("date", "Date", "datetime", "time")
    _TARGET_ALIASES = ("close", "Close", "price", "value", "y")
    # Temporary join key used to align rows with daily covariates.
    _DAY_COLUMN = "__day__"
    # Frequency used when it cannot be inferred from the observed timestamps.
    _DEFAULT_FREQ = "D"

    def __init__(
        self,
        config: Union[JNUTSBConfig, Dict[str, Any]],
        chronos_device_map: Optional[str] = None,
        llm_device_map: Optional[str] = None,
    ) -> None:
        """Create a runtime.

        Args:
            config: A ``JNUTSBConfig`` or a dict of its constructor kwargs.
                A ``JNUTSBConfig`` instance is stored as-is (not copied).
            chronos_device_map: Device map for Chronos-2; defaults to the
                ``JNU_TSB_CHRONOS_DEVICE_MAP`` env var, then ``"cpu"``.
            llm_device_map: Device map for the LLM; defaults to the
                ``JNU_TSB_LLM_DEVICE_MAP`` env var, then ``"cpu"``.
        """
        if isinstance(config, dict):
            config = JNUTSBConfig(**config)
        self.config = config
        self.chronos_device_map = chronos_device_map or os.getenv("JNU_TSB_CHRONOS_DEVICE_MAP", "cpu")
        self.llm_device_map = llm_device_map or os.getenv("JNU_TSB_LLM_DEVICE_MAP", "cpu")
        self._chronos = None
        self._llm_pipe = None
        self._extractor = None

    @classmethod
    def from_config(cls, config: Union[JNUTSBConfig, Dict[str, Any]], **kwargs: Any) -> "JNUTSBRuntime":
        """Alternate constructor; equivalent to ``cls(config=config, **kwargs)``."""
        return cls(config=config, **kwargs)

    @classmethod
    def from_config_dir(cls, path: Union[str, os.PathLike[str]], **kwargs: Any) -> "JNUTSBRuntime":
        """Build a runtime from ``<path>/config.json`` (read as UTF-8 JSON)."""
        path = Path(path)
        with open(path / "config.json", "r", encoding="utf-8") as f:
            payload = json.load(f)
        return cls(config=payload, **kwargs)

    @property
    def chronos(self):
        """Lazily load and cache the Chronos-2 pipeline.

        Raises:
            ImportError: if ``chronos.Chronos2Pipeline`` cannot be imported.
        """
        if self._chronos is None:
            try:
                from chronos import Chronos2Pipeline
            except Exception as exc:  # pragma: no cover
                raise ImportError(
                    "chronos-forecasting is required for Chronos-2 inference. "
                    "Install it with: pip install chronos-forecasting"
                ) from exc
            self._chronos = Chronos2Pipeline.from_pretrained(
                self.config.chronos_model_id,
                device_map=self.chronos_device_map,
            )
        return self._chronos

    @property
    def extractor(self) -> EventExtractor:
        """Lazily build the event extractor, wiring in the LLM only when enabled."""
        if self._extractor is None:
            self._extractor = EventExtractor(
                generate_fn=self._generate_with_polyglot if self.config.use_llm_extractor else None,
                categories=self.config.event_categories,
                use_llm=self.config.use_llm_extractor,
            )
        return self._extractor

    def _generate_with_polyglot(self, prompt: str) -> str:
        """Generate a greedy completion (max 96 new tokens) for ``prompt``.

        The tokenizer/model/pipeline are loaded on the first call and cached.
        float16 is used when CUDA is available, float32 otherwise.
        """
        if self._llm_pipe is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline

            tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_id)
            model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map=self.llm_device_map,
            )
            self._llm_pipe = hf_pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        output = self._llm_pipe(
            prompt,
            max_new_tokens=96,
            do_sample=False,
            return_full_text=False,
        )
        if isinstance(output, list) and output:
            return output[0].get("generated_text", "")
        return str(output)

    def predict(
        self,
        inputs: Optional[Dict[str, Any]] = None,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[Sequence[float]] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Route the payload to the hybrid, chronos-only or text-only path.

        Args:
            inputs: Payload dict; recognised keys are ``stock``, ``news``,
                ``future_news`` and ``future_covariates``. Extra ``kwargs``
                are merged on top of it.
            prediction_length: Forecast horizon; falsy values fall back to
                ``config.prediction_length``.
            quantile_levels: Quantiles to forecast; falsy values fall back to
                ``config.quantile_levels``.
            use_llm_extractor: If given and different from the config, the
                config is updated and the extractor rebuilt (persists).
            allow_naive_fallback: Overrides ``config.allow_naive_fallback``.

        Returns:
            A JSON-friendly dict whose ``route`` key is ``"hybrid"``,
            ``"chronos_only"`` or ``"text_only"``.

        Raises:
            ValueError: if neither usable stock data nor news is provided.
        """
        payload: Dict[str, Any] = dict(inputs or {})
        payload.update(kwargs)

        if use_llm_extractor is not None and bool(use_llm_extractor) != self.config.use_llm_extractor:
            # Rebuild extractor with the requested setting for this runtime instance.
            # NOTE: this mutates self.config, which may be the caller's own
            # JNUTSBConfig object (__init__ does not copy it), so the override
            # persists for later calls too.
            self.config.use_llm_extractor = bool(use_llm_extractor)
            self._extractor = None

        prediction_length = int(prediction_length or self.config.prediction_length)
        quantile_levels = list(quantile_levels or self.config.quantile_levels)
        if allow_naive_fallback is None:
            allow_naive_fallback = self.config.allow_naive_fallback
        else:
            allow_naive_fallback = bool(allow_naive_fallback)

        news = payload.get("news")
        stock = payload.get("stock")
        future_news = payload.get("future_news")
        future_covariates = payload.get("future_covariates")

        has_news = bool(news)
        stock_df = self._prepare_stock_df(stock)
        has_stock = stock_df is not None and not stock_df.empty

        if has_news and has_stock:
            context_df = self._merge_news_covariates(stock_df, news)
            future_df = self._prepare_future_covariates(
                stock_df=context_df,
                future_news=future_news,
                future_covariates=future_covariates,
                prediction_length=prediction_length,
            )
            return self._forecast(
                context_df=context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="hybrid",
                future_df=future_df,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_stock:
            return self._forecast(
                context_df=stock_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="chronos_only",
                future_df=None,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_news:
            # Each item's text is the first non-empty of title / headline / text.
            events = [
                self.extractor.extract(item.get("title") or item.get("headline") or item.get("text") or "")
                for item in news
            ]
            daily_covariates = self.extractor.aggregate_to_daily(news)
            return {
                "route": "text_only",
                "repo_id": self.config.repo_id,
                "events": events,
                "daily_covariates": self._df_to_records(daily_covariates),
            }

        raise ValueError("JNU-TSB expects at least one of: stock, news.")

    def _forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: Sequence[float],
        route: str,
        future_df: Optional[pd.DataFrame] = None,
        allow_naive_fallback: bool = True,
    ) -> Dict[str, Any]:
        """Forecast with Chronos-2 ``predict_df``, optionally falling back.

        Any exception raised while loading Chronos-2, predicting or
        serialising the result is re-raised when ``allow_naive_fallback`` is
        False; otherwise a last-value forecast is returned together with a
        ``warning`` describing the failure.
        """
        try:
            kwargs = dict(
                prediction_length=prediction_length,
                quantile_levels=list(quantile_levels),
                id_column=self.config.id_column,
                timestamp_column=self.config.timestamp_column,
                target=self.config.target_column,
            )
            if future_df is not None and not future_df.empty:
                pred = self.chronos.predict_df(context_df, future_df=future_df, **kwargs)
            else:
                pred = self.chronos.predict_df(context_df, **kwargs)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": self.config.chronos_model_id,
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": False,
            }
        except Exception as exc:
            if not allow_naive_fallback:
                raise
            pred = self._naive_forecast(context_df, prediction_length, quantile_levels)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": "naive_last_value_fallback",
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": True,
                "warning": f"Chronos-2 inference failed or was unavailable: {type(exc).__name__}: {exc}",
            }

    @staticmethod
    def _coerce_frame(data: Any, label: str) -> pd.DataFrame:
        """Return a fresh DataFrame built from a DataFrame, list of dicts or dict of columns.

        Raises:
            TypeError: for any other input type; ``label`` names the argument.
        """
        if isinstance(data, pd.DataFrame):
            return data.copy()
        if isinstance(data, (list, dict)):
            return pd.DataFrame(data)
        raise TypeError(f"{label} must be a pandas DataFrame, list of dicts, or dict of columns.")

    @staticmethod
    def _rename_first_present(df: pd.DataFrame, column: str, aliases: Sequence[str]) -> pd.DataFrame:
        """Rename the first alias found in ``df`` to ``column`` unless ``column`` already exists."""
        if column in df.columns:
            return df
        for alias in aliases:
            if alias in df.columns:
                return df.rename(columns={alias: column})
        return df

    def _finalize_frame(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fill a default item id if missing, parse timestamps and sort by (id, timestamp)."""
        id_col = self.config.id_column
        timestamp_col = self.config.timestamp_column
        if id_col not in df.columns:
            df[id_col] = self.config.default_item_id
        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
        return df.sort_values([id_col, timestamp_col]).reset_index(drop=True)

    def _prepare_stock_df(self, stock: Any) -> Optional[pd.DataFrame]:
        """Normalise user stock input into a sorted context frame.

        Returns ``None`` for ``None`` input and an empty frame for empty input.

        Raises:
            TypeError: for unsupported input types.
            ValueError: if no timestamp or target column (or alias) is found.
        """
        if stock is None:
            return None
        df = self._coerce_frame(stock, "stock")
        if df.empty:
            return df

        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        df = self._rename_first_present(df, timestamp_col, self._TIMESTAMP_ALIASES)
        df = self._rename_first_present(df, target_col, self._TARGET_ALIASES)
        if timestamp_col not in df.columns or target_col not in df.columns:
            raise ValueError(f"stock must contain '{timestamp_col}' and '{target_col}' columns.")
        return self._finalize_frame(df)

    def _prepare_future_df(self, data: Any) -> Optional[pd.DataFrame]:
        """Normalise user-supplied future covariates into a sorted frame.

        Returns ``None`` for ``None`` input and an empty frame for empty input.

        Raises:
            TypeError: for unsupported input types.
            ValueError: if no timestamp column (or alias) is found.
        """
        if data is None:
            return None
        df = self._coerce_frame(data, "future_covariates")
        if df.empty:
            return df

        timestamp_col = self.config.timestamp_column
        df = self._rename_first_present(df, timestamp_col, self._TIMESTAMP_ALIASES)
        if timestamp_col not in df.columns:
            raise ValueError(f"future_covariates must contain a '{timestamp_col}' column.")
        return self._finalize_frame(df)

    def _attach_daily_covariates(self, frame: pd.DataFrame, cov: pd.DataFrame) -> pd.DataFrame:
        """Left-join daily covariates onto ``frame`` by calendar day.

        Rows are matched on the timestamp floored to the day. Missing values in
        ``COVARIATE_COLUMNS`` are filled with 0 and cast to float. Assumes
        ``cov`` keys its days in a ``"timestamp"`` column, as produced by
        ``EventExtractor.aggregate_to_daily`` — TODO confirm dtype alignment.
        """
        day_col = self._DAY_COLUMN
        left = frame.copy()
        left[day_col] = pd.to_datetime(left[self.config.timestamp_column]).dt.floor("D")
        right = cov.rename(columns={"timestamp": day_col})
        merged = left.merge(right, on=day_col, how="left").drop(columns=[day_col])
        for col in COVARIATE_COLUMNS:
            if col in merged.columns:
                merged[col] = merged[col].fillna(0).astype(float)
        return merged

    def _merge_news_covariates(self, stock_df: pd.DataFrame, news: Iterable[Dict[str, Any]]) -> pd.DataFrame:
        """Aggregate ``news`` to daily covariates and join them onto the stock context."""
        cov = self.extractor.aggregate_to_daily(news)
        return self._attach_daily_covariates(stock_df, cov)

    @classmethod
    def _infer_freq(cls, timestamps: pd.Series) -> str:
        """Infer a pandas frequency string from ``timestamps``, defaulting to ``"D"``.

        ``pd.infer_freq`` raises ``ValueError`` for fewer than three values
        rather than returning ``None``, so short histories are handled here.
        Without this guard, the naive fallback itself crashed on 1-2 row contexts.
        """
        stamps = pd.to_datetime(timestamps).drop_duplicates().sort_values()
        if len(stamps) < 3:
            return cls._DEFAULT_FREQ
        try:
            return pd.infer_freq(stamps) or cls._DEFAULT_FREQ
        except (TypeError, ValueError):
            return cls._DEFAULT_FREQ

    def _prepare_future_covariates(
        self,
        stock_df: pd.DataFrame,
        future_news: Optional[Iterable[Dict[str, Any]]],
        future_covariates: Any,
        prediction_length: int,
    ) -> Optional[pd.DataFrame]:
        """Build the future-covariate frame passed to Chronos-2.

        Explicit ``future_covariates`` take precedence (the target column is
        dropped). Otherwise, if ``future_news`` is given, ``prediction_length``
        future timestamps are generated after the last context timestamp for
        the first item id only, and daily news covariates are joined on.
        Returns ``None`` when neither source is available.
        """
        if future_covariates is not None:
            fut = self._prepare_future_df(future_covariates)
            if fut is not None and not fut.empty:
                return fut.drop(columns=[self.config.target_column], errors="ignore")

        if not future_news:
            return None

        timestamp_col = self.config.timestamp_column
        id_col = self.config.id_column
        first_id = stock_df[id_col].iloc[0]
        timestamps = pd.to_datetime(stock_df[timestamp_col])
        last_ts = timestamps.max()
        freq = self._infer_freq(timestamps)
        # periods + 1 then drop the first entry so the horizon starts after last_ts.
        future_dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
        base = pd.DataFrame({
            timestamp_col: future_dates,
            id_col: first_id,
        })

        cov = self.extractor.aggregate_to_daily(future_news)
        if cov.empty:
            return base
        return self._attach_daily_covariates(base, cov)

    def _naive_forecast(self, context_df: pd.DataFrame, prediction_length: int, quantile_levels: Sequence[float]) -> pd.DataFrame:
        """Repeat each item's last observed target value over the horizon.

        Every quantile column (both ``"<q>"`` and ``"q<q>"`` spellings),
        ``mean`` and ``prediction`` hold that same value.
        """
        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        id_col = self.config.id_column
        rows: List[Dict[str, Any]] = []
        for item_id, group in context_df.groupby(id_col):
            group = group.sort_values(timestamp_col)
            last_ts = pd.to_datetime(group[timestamp_col].iloc[-1])
            last_value = float(group[target_col].iloc[-1])
            freq = self._infer_freq(group[timestamp_col])
            dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
            for ts in dates:
                row: Dict[str, Any] = {id_col: item_id, timestamp_col: ts}
                for q in quantile_levels:
                    row[str(q)] = last_value
                    row[f"q{q}"] = last_value
                row["mean"] = last_value
                row["prediction"] = last_value
                rows.append(row)
        return pd.DataFrame(rows)

    def _df_to_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert ``df`` to a list of row dicts, stringifying datetime columns."""
        out = df.copy()
        for col in out.columns:
            if pd.api.types.is_datetime64_any_dtype(out[col]):
                out[col] = out[col].astype(str)
        return out.to_dict(orient="records")