File size: 5,239 Bytes
d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b 0b929da d24798b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import datetime as dt
import pandas as pd
import torch
import gradio as gr
import yfinance as yf
from chronos import BaseChronosPipeline # from 'chronos-forecasting'
# ---- Global cache: each model is loaded once and then reused ----
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it at most once.

    Loaded pipelines are stored in ``_PIPELINE_CACHE`` under the key
    ``(model_id, device)``, so later calls with the same pair get the same
    object back. Weights are loaded as float32 on "cpu" and as bfloat16 on
    any other device.
    """
    cache_key = (model_id, device)
    try:
        return _PIPELINE_CACHE[cache_key]
    except KeyError:
        pass  # first request for this (model, device) pair -- load it below

    weights_dtype = torch.float32 if device == "cpu" else torch.bfloat16
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,  # e.g. "cpu" / "cuda"
        torch_dtype=weights_dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
# ---- Stock / crypto price loading (yfinance, hardened) ----

# Yahoo only serves intraday bars for a limited look-back window (about 730 days
# for hourly bars, 60 days for minute bars); an older start date yields no rows.
# One day of margin is kept below each limit. Yahoo may change these limits.
_INTRADAY_MAX_LOOKBACK_DAYS = {"1h": 729, "60m": 729, "30m": 59, "15m": 59, "5m": 59}


def _clamp_intraday_start(start, interval):
    """Return ``start`` moved forward into Yahoo's look-back window for an intraday ``interval``.

    Daily (or unknown) intervals and start values that cannot be parsed are
    returned unchanged, so yfinance keeps its own validation and messages.
    """
    max_days = _INTRADAY_MAX_LOOKBACK_DAYS.get(interval)
    if max_days is None:
        return start
    try:
        requested = pd.Timestamp(start).date()
    except (TypeError, ValueError):
        return start
    earliest = dt.date.today() - dt.timedelta(days=max_days)
    return max(requested, earliest).isoformat()


def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
    """Download closing prices for ``ticker`` as a NaN-free float ``pd.Series``.

    ``Ticker.history()`` is tried first, because crypto symbols such as BTC-USD
    intermittently hit timezone/parsing errors. If it raises or returns no
    ``Close`` column, ``yf.download()`` is tried once as a fallback. Both paths
    request adjusted prices (``auto_adjust=True``) so they return the same kind
    of data.

    Args:
        ticker: Yahoo Finance symbol, e.g. "AAPL", "005930.KS", "BTC-USD".
        start: First date "YYYY-MM-DD"; blank means "2014-09-17". For intraday
            intervals it is clamped into Yahoo's look-back window.
        end: End date "YYYY-MM-DD" passed to yfinance; blank means today.
        interval: yfinance bar size, e.g. "1d", "1h", "30m", "15m", "5m".

    Returns:
        Float Series of closing prices indexed by bar timestamp.

    Raises:
        ValueError: if no usable ``Close`` data could be downloaded.
    """
    # Treat whitespace-only text-box input as blank.
    if isinstance(start, str):
        start = start.strip()
    if isinstance(end, str):
        end = end.strip()
    # Default to a long history: a very recent window alone can come back empty.
    _start = start or "2014-09-17"  # close to BTC-USD's first listing date
    _end = end or dt.date.today().isoformat()
    _start = _clamp_intraday_start(_start, interval)

    tk = yf.Ticker(ticker)
    try:
        df = tk.history(start=_start, end=_end, interval=interval, auto_adjust=True, actions=False)
        if df.empty or "Close" not in df:
            raise ValueError("empty")
    except Exception:
        # Deliberate best-effort: any failure of the history() path falls back to download().
        df = yf.download(
            ticker,
            start=_start,
            end=_end,
            interval=interval,
            auto_adjust=True,
            progress=False,
            threads=False,
        )
    if df.empty or "Close" not in df:
        raise ValueError("데이터가 없거나 'Close' 열이 없습니다. 티커/날짜/간격을 확인하세요.")

    close = df["Close"]
    # Recent yfinance versions return (Price, Ticker) MultiIndex columns from
    # download() even for a single symbol, making df["Close"] a DataFrame;
    # keep the first ticker's column so callers always receive a 1-D Series.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    s = close.dropna().astype(float)
    if s.empty:
        raise ValueError("다운로드 결과가 비어 있습니다. 기간/간격을 줄이거나 다시 시도하세요.")
    return s
# ---- Forecast callback (invoked by Gradio) ----

# yfinance interval -> pandas offset alias used to space the future x-axis.
# Lower-case "h"/"min" avoid the FutureWarning pandas >= 2.2 emits for "H"/"T".
_FREQ_MAP = {"1d": "D", "1h": "h", "30m": "30min", "15m": "15min", "5m": "5min"}
# Quantile levels requested from the model: lower band, median, upper band.
_QUANTILE_LEVELS = [0.1, 0.5, 0.9]


def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Load prices, forecast ``horizon`` steps with Chronos, and build the UI outputs.

    Args:
        ticker: Yahoo Finance symbol typed by the user (surrounding whitespace ignored).
        start_date: Start date string forwarded to ``load_close_series``.
        end_date: End date string forwarded to ``load_close_series``.
        horizon: Number of future steps H (cast to ``int``).
        model_id: Model id passed to ``get_pipeline``.
        device: Torch device string passed to ``get_pipeline``.
        interval: yfinance bar size; also selects the spacing of the future x-axis.

    Returns:
        ``(figure, table, note)`` for the Plot, Dataframe and Markdown outputs.
        ``table`` has columns q10/q50/q90 indexed by step 1..H. On a data or
        model error, ``figure`` is None, ``table`` is empty and ``note`` holds
        the error message.
    """
    ticker = (ticker or "").strip()
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:
        # Gradio v4 has no Plot.update, so an empty plot is signalled by returning None.
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    H = int(horizon)
    # Chronos input: 1-D float32 tensor of the observed closes.
    context = torch.tensor(series.to_numpy(), dtype=torch.float32)
    try:
        pipe = get_pipeline(model_id, device)
        # predict() returns the model's native output (nine fixed quantiles for
        # Chronos-Bolt, sample paths for original Chronos), so its first rows are
        # NOT q10/q50/q90. predict_quantiles() returns exactly the requested
        # levels, shaped (num_series=1, H, len(_QUANTILE_LEVELS)).
        quantiles, _mean = pipe.predict_quantiles(
            context, prediction_length=H, quantile_levels=_QUANTILE_LEVELS
        )
    except Exception as e:
        # Report model failures the same way as data failures instead of crashing the UI.
        return None, pd.DataFrame(), f"모델 예측 오류: {e}"

    # Drop the batch axis; cast to float32 on CPU because NumPy has no bfloat16.
    q = torch.as_tensor(quantiles[0]).detach().to(device="cpu", dtype=torch.float32).numpy()
    q10, q50, q90 = q[:, 0], q[:, 1], q[:, 2]

    # Table data: one row per forecast step.
    df_fcst = pd.DataFrame(
        {"q10": q10, "q50": q50, "q90": q90},
        index=pd.RangeIndex(1, H + 1, name="step"),
    )

    # Future x-axis: H timestamps after the last observation, spaced by `interval`.
    # NOTE: "1d" maps to calendar days, so stock forecasts also land on weekends.
    freq = _FREQ_MAP.get(interval, "D")
    future_index = pd.date_range(series.index[-1], periods=H + 1, freq=freq)[1:]

    # Use a standalone Figure instead of pyplot: pyplot registers every figure in
    # global state (one leaked figure per request) and is not safe to drive from
    # Gradio's worker threads.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(series.index, series.to_numpy(), label="history")
    ax.plot(future_index, q50, label="forecast(q50)")
    ax.fill_between(future_index, q10, q90, alpha=0.2, label="q10–q90")
    ax.set_title(f"{ticker} forecast by Chronos-Bolt ({interval}, H={H})")
    ax.legend()
    fig.tight_layout()

    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note
# ---- Gradio UI ----
# Three rows of inputs (ticker/horizon, date range, model/device/interval),
# then the run button and the three outputs filled by run_forecast.
with gr.Blocks(title="Chronos Stock/Crypto Forecast") as demo:
    gr.Markdown("# Chronos 주가·크립토 예측 데모")
    with gr.Row():
        ticker = gr.Textbox(value="BTC-USD", label="티커 (예: AAPL, MSFT, 005930.KS, BTC-USD)")
        horizon = gr.Slider(5, 365, value=90, step=1, label="예측 스텝 H (간격 단위와 동일)")
    with gr.Row():
        start = gr.Textbox(value="2014-09-17", label="시작일 (YYYY-MM-DD, 예: 2014-09-17)")
        # A callable default is re-evaluated on every page load; a plain string
        # would freeze "today" at server start-up on a long-running app.
        end = gr.Textbox(
            value=lambda: dt.date.today().isoformat(),
            label="종료일 (YYYY-MM-DD, 비워두면 오늘)",
        )
    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(
            choices=["1d", "1h", "30m", "15m", "5m"],
            value="1d",
            label="간격",
        )
    btn = gr.Button("예측 실행")
    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()
    # Input order must match run_forecast's positional parameters.
    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )

if __name__ == "__main__":
    demo.launch()
|