Time Series Forecasting
Transformers
PyTorch
Korean
jnu_tsb
feature-extraction
jnu-tsb
time-series
forecasting
chronos-2
polyglot-ko
korean
finance
covariates
r
reticulate
education
custom_code
Instructions to use HONGRIZON/JNU-TSB with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use HONGRIZON/JNU-TSB with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("HONGRIZON/JNU-TSB", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional, Sequence, Union | |
| import pandas as pd | |
| import torch | |
| try: | |
| from .configuration_jnu_tsb import JNUTSBConfig | |
| from .event_extractor import COVARIATE_COLUMNS, EventExtractor | |
| except ImportError: # pragma: no cover - local execution fallback | |
| from configuration_jnu_tsb import JNUTSBConfig | |
| from event_extractor import COVARIATE_COLUMNS, EventExtractor | |
class JNUTSBRuntime:
    """Runtime used by the model wrapper, pipeline, Endpoint handler, Gradio, and R examples.

    Routes inputs into three paths:

    1. stock only   -> Chronos-2 forecast
    2. news only    -> event extraction and daily covariates
    3. stock + news -> news covariates + stock context -> Chronos-2 forecast

    The Chronos-2 pipeline, the text-generation pipeline and the event
    extractor are created lazily on first access and cached on the instance.
    """

    # Column names accepted in place of the configured timestamp / target
    # columns; the first one found is renamed to the configured name.
    _TIMESTAMP_ALIASES = ("date", "Date", "datetime", "time")
    _TARGET_ALIASES = ("close", "Close", "price", "value", "y")
    # Frequency used when pandas cannot infer one from the history.
    _FALLBACK_FREQ = "D"

    def __init__(
        self,
        config: Union[JNUTSBConfig, Dict[str, Any]],
        chronos_device_map: Optional[str] = None,
        llm_device_map: Optional[str] = None,
    ) -> None:
        """Store the configuration and resolve device maps.

        Args:
            config: A ``JNUTSBConfig`` or a plain dict of its keyword arguments.
            chronos_device_map: Device map for Chronos-2. Falls back to the
                ``JNU_TSB_CHRONOS_DEVICE_MAP`` environment variable, then ``"cpu"``.
            llm_device_map: Device map for the causal LM. Falls back to the
                ``JNU_TSB_LLM_DEVICE_MAP`` environment variable, then ``"cpu"``.
        """
        if isinstance(config, dict):
            config = JNUTSBConfig(**config)
        self.config = config
        self.chronos_device_map = chronos_device_map or os.getenv("JNU_TSB_CHRONOS_DEVICE_MAP", "cpu")
        self.llm_device_map = llm_device_map or os.getenv("JNU_TSB_LLM_DEVICE_MAP", "cpu")
        # Heavy components are loaded on demand (see the properties below).
        self._chronos = None
        self._llm_pipe = None
        self._extractor = None

    @classmethod
    def from_config(cls, config: Union[JNUTSBConfig, Dict[str, Any]], **kwargs: Any) -> "JNUTSBRuntime":
        """Build a runtime from a config object or dict; ``kwargs`` go to ``__init__``."""
        return cls(config=config, **kwargs)

    @classmethod
    def from_config_dir(cls, path: Union[str, os.PathLike[str]], **kwargs: Any) -> "JNUTSBRuntime":
        """Build a runtime from ``<path>/config.json``; ``kwargs`` go to ``__init__``."""
        config_path = Path(path) / "config.json"
        with open(config_path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        return cls(config=payload, **kwargs)

    @property
    def chronos(self):
        """Return the cached Chronos-2 pipeline, loading it on first access.

        Raises:
            ImportError: If the ``chronos`` package cannot be imported.
        """
        if self._chronos is None:
            try:
                from chronos import Chronos2Pipeline
            except Exception as exc:  # pragma: no cover
                raise ImportError(
                    "chronos-forecasting is required for Chronos-2 inference. "
                    "Install it with: pip install chronos-forecasting"
                ) from exc
            self._chronos = Chronos2Pipeline.from_pretrained(
                self.config.chronos_model_id,
                device_map=self.chronos_device_map,
            )
        return self._chronos

    @property
    def extractor(self) -> EventExtractor:
        """Return the cached ``EventExtractor``, building it on first access.

        The LLM generation callback is only wired in when
        ``config.use_llm_extractor`` is true.
        """
        if self._extractor is None:
            self._extractor = EventExtractor(
                generate_fn=self._generate_with_polyglot if self.config.use_llm_extractor else None,
                categories=self.config.event_categories,
                use_llm=self.config.use_llm_extractor,
            )
        return self._extractor

    def _generate_with_polyglot(self, prompt: str) -> str:
        """Generate a greedy completion for ``prompt`` with ``config.llm_model_id``.

        The tokenizer, model and text-generation pipeline are created on the
        first call and reused afterwards. Only the newly generated text is
        returned (``return_full_text=False``).
        """
        if self._llm_pipe is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline

            tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_id)
            model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map=self.llm_device_map,
            )
            self._llm_pipe = hf_pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        output = self._llm_pipe(
            prompt,
            max_new_tokens=96,
            do_sample=False,
            return_full_text=False,
        )
        if isinstance(output, list) and output:
            return output[0].get("generated_text", "")
        return str(output)

    def predict(
        self,
        inputs: Optional[Dict[str, Any]] = None,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[Sequence[float]] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Route a request to the hybrid, forecast-only or text-only path.

        Args:
            inputs: Payload dict; the keys read are ``stock``, ``news``,
                ``future_news`` and ``future_covariates``.
            prediction_length: Forecast horizon; defaults to the config value.
            quantile_levels: Quantiles to report; defaults to the config value.
            use_llm_extractor: When given and different from the current
                config value, the config is updated and the extractor rebuilt.
            allow_naive_fallback: Whether a failed Chronos-2 call may fall back
                to the last-value forecast; defaults to the config value.
            **kwargs: Extra payload entries, merged over ``inputs``.

        Returns:
            A dict with a ``route`` key (``"hybrid"``, ``"chronos_only"`` or
            ``"text_only"``) plus the forecast or the extracted events.

        Raises:
            ValueError: If neither usable stock data nor news is supplied.
        """
        payload: Dict[str, Any] = dict(inputs or {})
        payload.update(kwargs)

        if use_llm_extractor is not None and bool(use_llm_extractor) != self.config.use_llm_extractor:
            # Rebuild extractor with the requested setting for this runtime instance.
            self.config.use_llm_extractor = bool(use_llm_extractor)
            self._extractor = None

        prediction_length = int(prediction_length or self.config.prediction_length)
        quantile_levels = list(quantile_levels or self.config.quantile_levels)
        allow_naive_fallback = (
            self.config.allow_naive_fallback if allow_naive_fallback is None else bool(allow_naive_fallback)
        )

        news = payload.get("news")
        stock = payload.get("stock")
        future_news = payload.get("future_news")
        future_covariates = payload.get("future_covariates")

        has_news = bool(news)
        stock_df = self._prepare_stock_df(stock)
        has_stock = stock_df is not None and not stock_df.empty

        if has_news and has_stock:
            context_df = self._merge_news_covariates(stock_df, news)
            future_df = self._prepare_future_covariates(
                stock_df=context_df,
                future_news=future_news,
                future_covariates=future_covariates,
                prediction_length=prediction_length,
            )
            return self._forecast(
                context_df=context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="hybrid",
                future_df=future_df,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_stock:
            return self._forecast(
                context_df=stock_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="chronos_only",
                future_df=None,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_news:
            events = [
                self.extractor.extract(item.get("title") or item.get("headline") or item.get("text") or "")
                for item in news
            ]
            daily_covariates = self.extractor.aggregate_to_daily(news)
            return {
                "route": "text_only",
                "repo_id": self.config.repo_id,
                "events": events,
                "daily_covariates": self._df_to_records(daily_covariates),
            }

        raise ValueError("JNU-TSB expects at least one of: stock, news.")

    def _forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: Sequence[float],
        route: str,
        future_df: Optional[pd.DataFrame] = None,
        allow_naive_fallback: bool = True,
    ) -> Dict[str, Any]:
        """Forecast with Chronos-2, optionally degrading to a last-value forecast.

        Any exception raised while loading or running Chronos-2 is re-raised
        when ``allow_naive_fallback`` is false; otherwise the naive forecast is
        returned together with a ``warning`` describing the failure.
        """
        try:
            kwargs = dict(
                prediction_length=prediction_length,
                quantile_levels=list(quantile_levels),
                id_column=self.config.id_column,
                timestamp_column=self.config.timestamp_column,
                target=self.config.target_column,
            )
            if future_df is not None and not future_df.empty:
                pred = self.chronos.predict_df(context_df, future_df=future_df, **kwargs)
            else:
                pred = self.chronos.predict_df(context_df, **kwargs)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": self.config.chronos_model_id,
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": False,
            }
        except Exception as exc:
            # Deliberately broad: this is the best-effort boundary of the runtime.
            if not allow_naive_fallback:
                raise
            pred = self._naive_forecast(context_df, prediction_length, quantile_levels)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": "naive_last_value_fallback",
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": True,
                "warning": f"Chronos-2 inference failed or was unavailable: {type(exc).__name__}: {exc}",
            }

    @staticmethod
    def _to_frame(data: Any, label: str) -> Optional[pd.DataFrame]:
        """Coerce a DataFrame, list of dicts or dict of columns into a new DataFrame."""
        if data is None:
            return None
        if isinstance(data, pd.DataFrame):
            # Copy so later column assignments never touch the caller's frame.
            return data.copy()
        if isinstance(data, (list, dict)):
            return pd.DataFrame(data)
        raise TypeError(f"{label} must be a pandas DataFrame, list of dicts, or dict of columns.")

    @staticmethod
    def _rename_alias(df: pd.DataFrame, wanted: str, aliases: Sequence[str]) -> pd.DataFrame:
        """Rename the first alias present in ``df`` to ``wanted`` if ``wanted`` is absent."""
        if wanted in df.columns:
            return df
        for cand in aliases:
            if cand in df.columns:
                return df.rename(columns={cand: wanted})
        return df

    def _finalize_frame(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add the default id column if missing, parse timestamps and sort by (id, timestamp)."""
        id_col = self.config.id_column
        timestamp_col = self.config.timestamp_column
        if id_col not in df.columns:
            df[id_col] = self.config.default_item_id
        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
        return df.sort_values([id_col, timestamp_col]).reset_index(drop=True)

    def _infer_freq(self, timestamps: pd.Series) -> str:
        """Infer a pandas frequency string from ``timestamps``, defaulting to daily.

        ``pd.infer_freq`` raises ``ValueError`` for fewer than three timestamps
        (and ``TypeError`` for unsupported dtypes), so those cases are mapped to
        the fallback instead of crashing on short histories.
        """
        unique_sorted = pd.to_datetime(timestamps).drop_duplicates().sort_values()
        try:
            freq = pd.infer_freq(unique_sorted)
        except (TypeError, ValueError):
            freq = None
        return freq or self._FALLBACK_FREQ

    def _prepare_stock_df(self, stock: Any) -> Optional[pd.DataFrame]:
        """Normalise stock history into a sorted frame with id/timestamp/target columns.

        Returns ``None`` for ``None`` input and the empty frame unchanged for
        empty input.

        Raises:
            TypeError: If ``stock`` is not a DataFrame, list or dict.
            ValueError: If the timestamp or target column cannot be found,
                even after trying the known aliases.
        """
        df = self._to_frame(stock, "stock")
        if df is None or df.empty:
            return df
        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        df = self._rename_alias(df, timestamp_col, self._TIMESTAMP_ALIASES)
        df = self._rename_alias(df, target_col, self._TARGET_ALIASES)
        if timestamp_col not in df.columns or target_col not in df.columns:
            raise ValueError(f"stock must contain '{timestamp_col}' and '{target_col}' columns.")
        return self._finalize_frame(df)

    def _prepare_future_df(self, data: Any) -> Optional[pd.DataFrame]:
        """Normalise user-supplied future covariates into a sorted frame.

        Returns ``None`` for ``None`` input and the empty frame unchanged for
        empty input.

        Raises:
            TypeError: If ``data`` is not a DataFrame, list or dict.
            ValueError: If no timestamp column (or alias) is present.
        """
        df = self._to_frame(data, "future_covariates")
        if df is None or df.empty:
            return df
        timestamp_col = self.config.timestamp_column
        df = self._rename_alias(df, timestamp_col, self._TIMESTAMP_ALIASES)
        if timestamp_col not in df.columns:
            raise ValueError(f"future_covariates must contain a '{timestamp_col}' column.")
        return self._finalize_frame(df)

    def _merge_news_covariates(self, stock_df: pd.DataFrame, news: Iterable[Dict[str, Any]]) -> pd.DataFrame:
        """Left-join daily news covariates onto the stock history by calendar day.

        Rows with no matching news day get ``0.0`` in every covariate column
        listed in ``COVARIATE_COLUMNS`` that is present after the merge.
        """
        cov = self.extractor.aggregate_to_daily(news)
        context = stock_df.copy()
        day_col = "__day__"
        context[day_col] = pd.to_datetime(context[self.config.timestamp_column]).dt.floor("D")
        cov = cov.rename(columns={"timestamp": day_col})
        merged = context.merge(cov, on=day_col, how="left").drop(columns=[day_col])
        for col in COVARIATE_COLUMNS:
            if col in merged.columns:
                merged[col] = merged[col].fillna(0).astype(float)
        return merged

    def _prepare_future_covariates(
        self,
        stock_df: pd.DataFrame,
        future_news: Optional[Iterable[Dict[str, Any]]],
        future_covariates: Any,
        prediction_length: int,
    ) -> Optional[pd.DataFrame]:
        """Build the future covariate frame for the forecast horizon.

        Explicit ``future_covariates`` win (with the target column dropped).
        Otherwise, if ``future_news`` is given, a grid of ``prediction_length``
        future timestamps is generated for every series in ``stock_df`` and the
        daily news covariates are joined onto it. Returns ``None`` when there
        is nothing to build from.
        """
        if future_covariates is not None:
            fut = self._prepare_future_df(future_covariates)
            if fut is not None and not fut.empty:
                return fut.drop(columns=[self.config.target_column], errors="ignore")
        if not future_news:
            return None

        timestamp_col = self.config.timestamp_column
        id_col = self.config.id_column

        # One future grid per series, each continuing from that series' own
        # last timestamp, so multi-series contexts are fully covered.
        grids: List[pd.DataFrame] = []
        for item_id, group in stock_df.groupby(id_col, sort=False):
            last_ts = pd.to_datetime(group[timestamp_col]).max()
            freq = self._infer_freq(group[timestamp_col])
            # Generate one extra period and drop the first so the grid starts
            # strictly after the last observed timestamp.
            future_dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
            grids.append(pd.DataFrame({timestamp_col: future_dates, id_col: item_id}))
        if not grids:
            return None
        base = pd.concat(grids, ignore_index=True)

        cov = self.extractor.aggregate_to_daily(future_news)
        if cov.empty:
            return base
        cov_day = cov.rename(columns={"timestamp": "__day__"})
        base["__day__"] = pd.to_datetime(base[timestamp_col]).dt.floor("D")
        merged = base.merge(cov_day, on="__day__", how="left").drop(columns=["__day__"])
        for col in COVARIATE_COLUMNS:
            if col in merged.columns:
                merged[col] = merged[col].fillna(0).astype(float)
        return merged

    def _naive_forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: Sequence[float],
    ) -> pd.DataFrame:
        """Repeat each series' last observed target value over the horizon.

        Every output row carries the last value under ``"mean"``,
        ``"prediction"`` and, for each quantile ``q``, both ``str(q)`` and
        ``f"q{q}"``.
        """
        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        id_col = self.config.id_column
        rows: List[Dict[str, Any]] = []
        for item_id, group in context_df.groupby(id_col):
            group = group.sort_values(timestamp_col)
            last_ts = pd.to_datetime(group[timestamp_col].iloc[-1])
            last_value = float(group[target_col].iloc[-1])
            freq = self._infer_freq(group[timestamp_col])
            dates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)[1:]
            for ts in dates:
                row: Dict[str, Any] = {id_col: item_id, timestamp_col: ts}
                for q in quantile_levels:
                    row[str(q)] = last_value
                    row[f"q{q}"] = last_value
                row["mean"] = last_value
                row["prediction"] = last_value
                rows.append(row)
        return pd.DataFrame(rows)

    def _df_to_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert ``df`` to a list of row dicts, stringifying datetime columns."""
        out = df.copy()
        for col in out.columns:
            if pd.api.types.is_datetime64_any_dtype(out[col]):
                out[col] = out[col].astype(str)
        return out.to_dict(orient="records")