| import os |
| import datetime as dt |
| import pandas as pd |
| import torch |
| import gradio as gr |
| import yfinance as yf |
|
|
| from chronos import BaseChronosPipeline |
|
|
| |
# Process-wide cache of loaded Chronos pipelines, keyed by (model_id, device),
# so each model is loaded at most once per process.
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it once.

    The first call for a given (model_id, device) pair builds the pipeline with
    ``BaseChronosPipeline.from_pretrained`` and stores it in ``_PIPELINE_CACHE``;
    later calls return that same object. Weights are float32 on CPU and
    bfloat16 on any other device.
    """
    cache_key = (model_id, device)
    try:
        return _PIPELINE_CACHE[cache_key]
    except KeyError:
        pass  # not loaded yet; fall through and build it outside the handler

    weight_dtype = torch.float32 if device == "cpu" else torch.bfloat16
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=weight_dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
|
|
| |
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
    """Download closing prices for ``ticker`` from Yahoo Finance.

    Args:
        ticker: Yahoo Finance symbol, e.g. ``"AAPL"`` or ``"005930.KS"``.
        start: First date to fetch, ``YYYY-MM-DD``.
        end: End date, ``YYYY-MM-DD``; passed unchanged to ``yf.download``.
        interval: Bar size passed to ``yf.download`` (default ``"1d"``).

    Returns:
        A 1-D float ``pd.Series`` named ``"Close"`` with NaN rows dropped,
        indexed like the downloaded frame.

    Raises:
        ValueError: if nothing was downloaded, there is no ``Close`` column,
            or every close value is missing.
    """
    no_data_msg = "데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요."

    df = yf.download(ticker, start=start, end=end, interval=interval, progress=False)
    if df is None or df.empty or "Close" not in df.columns:
        raise ValueError(no_data_msg)

    close = df["Close"]
    # Newer yfinance versions return (field, ticker) MultiIndex columns even for
    # a single symbol, so df["Close"] is a one-column DataFrame, not a Series.
    # Collapse it so callers always receive 1-D data (a 2-D frame would be read
    # by the model as a batch of length-1 series).
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0].rename("Close")

    series = close.dropna().astype(float)
    if series.empty:
        # A Close column made entirely of NaN is as unusable as no column.
        raise ValueError(no_data_msg)
    return series
|
|
| |
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Gradio callback: load closing prices, forecast them with Chronos, and plot.

    Args:
        ticker: Yahoo Finance symbol.
        start_date: ``YYYY-MM-DD`` start of the history window.
        end_date: ``YYYY-MM-DD`` end of the history window.
        horizon: Number of future steps to forecast (slider value; cast to int).
        model_id: Hugging Face id of the Chronos model.
        device: Torch device the model is loaded on.
        interval: yfinance bar interval.

    Returns:
        ``(figure, quantile_table, note)`` for the Plot, Dataframe and Markdown
        outputs. If loading the data fails, the figure is ``None``, the table is
        empty and the note carries the error message.
    """
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:
        # Return a plain None for the Plot output: gr.Plot.update() was removed
        # in Gradio 4, so calling it would crash this error path.
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    # Guard against a one-column DataFrame (newer yfinance column layout); the
    # model and the plot both need 1-D history.
    if isinstance(series, pd.DataFrame):
        series = series.iloc[:, 0]

    pipe = get_pipeline(model_id, device)
    H = int(horizon)

    q10, q50, q90 = _predict_quantiles(pipe, series, H)

    df_fcst = pd.DataFrame(
        {"q10": q10, "q50": q50, "q90": q90},
        index=pd.RangeIndex(1, H + 1, name="step"),
    )
    fig = _plot_forecast(series, q10, q50, q90, H, ticker)

    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note


def _predict_quantiles(pipe, series, horizon):
    """Return the 10th/50th/90th percentile forecasts as three 1-D numpy arrays."""
    context = torch.tensor(series.to_numpy(), dtype=torch.float32)
    # pipe.predict() on Chronos-Bolt returns every quantile the model was
    # trained on (0.1, 0.2, ..., 0.9), so rows 0/1/2 are q10/q20/q30, not
    # q10/q50/q90. predict_quantiles() returns exactly the requested levels.
    quantiles, _mean = pipe.predict_quantiles(
        context,
        prediction_length=horizon,
        quantile_levels=[0.1, 0.5, 0.9],
    )
    # quantiles: (batch, prediction_length, n_levels); batch is 1 here.
    # Cast and move to CPU so .numpy() also works for bf16/GPU outputs.
    levels = quantiles[0].detach().to(torch.float32).cpu().numpy()
    return levels[:, 0], levels[:, 1], levels[:, 2]


def _plot_forecast(series, q10, q50, q90, horizon, ticker):
    """Draw history, the median forecast and the q10–q90 band; return the Figure."""
    from matplotlib.figure import Figure

    history_index = series.index
    # Each forecast step is one more observation of the series. Stock history
    # has no weekend rows, so step forward in business days; if the history
    # does contain weekend rows (e.g. crypto), step in calendar days instead.
    has_weekend_rows = bool((history_index.dayofweek >= 5).any())
    freq = "D" if has_weekend_rows else "B"
    future_index = pd.date_range(history_index[-1], periods=horizon + 1, freq=freq)[1:]

    # A standalone Figure is not registered with pyplot. It is therefore
    # released once Gradio has rendered it, instead of accumulating in
    # pyplot's global figure list on every call, and it touches no pyplot
    # global state from Gradio's worker threads.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(history_index, series.to_numpy(), label="history")
    ax.plot(future_index, q50, label="forecast(q50)")
    ax.fill_between(future_index, q10, q90, alpha=0.2, label="q10–q90")
    ax.set_title(f"{ticker} forecast by Chronos-Bolt")
    ax.legend()
    fig.tight_layout()
    return fig
|
|
| |
def _default_start_date():
    """Return the date 365 days before today as a ``YYYY-MM-DD`` string."""
    return (dt.date.today() - dt.timedelta(days=365)).isoformat()


def _default_end_date():
    """Return today's date as a ``YYYY-MM-DD`` string."""
    return dt.date.today().isoformat()


# Gradio UI: input controls, a run button, and Plot/Dataframe/Markdown outputs
# wired to run_forecast.
with gr.Blocks(title="Chronos Stock Forecast") as demo:
    gr.Markdown("# Chronos 주가 예측 데모")
    with gr.Row():
        ticker = gr.Textbox(value="AAPL", label="티커 (예: AAPL, MSFT, 005930.KS)")
        horizon = gr.Slider(5, 60, value=20, step=1, label="예측 길이 H (일)")
    with gr.Row():
        # Callables are evaluated on each page load, so the default window
        # follows the current date instead of the date the server started.
        start = gr.Textbox(value=_default_start_date, label="시작일 (YYYY-MM-DD)")
        end = gr.Textbox(value=_default_end_date, label="종료일 (YYYY-MM-DD)")
    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(choices=["1d"], value="1d", label="간격")
    btn = gr.Button("예측 실행")

    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()

    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )


if __name__ == "__main__":
    demo.launch()
|
|