TimeFM-forecast / app_timefm.py
Jeongryeol's picture
Upload 5 files
2c6e8f4 verified
"""
종편 4사 시청률 예측 모델 — Google TimeFM 버전
"""
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
# Browser-tab title/icon and overall page layout for the app.
_PAGE_CONFIG = {
    "page_title": "종편 4사 시청률 예측 · TimeFM",
    "page_icon": "▲",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# ── CSS ──────────────────────────────────────────────────────
# Inject the app-wide dark-theme stylesheet as raw HTML. It styles the
# custom classes used by the markup emitted further down (.sidebar-title,
# .ch-card, .ch-name, .ch-value, .ch-unit, .ch-ci, .section-label,
# .help-text) and overrides several built-in widgets. Two media queries
# (<= 768px and <= 480px) reflow the column layout for narrow screens.
# NOTE(review): the selectors target Streamlit's `data-testid` attributes,
# which are presumably version-dependent — verify them after upgrading
# Streamlit.
st.markdown("""
<style>
/* 전체 배경 */
[data-testid="stAppViewContainer"] { background: #141414; }
[data-testid="stSidebar"] { background: #1a1a1a; border-right: 1px solid #2a2a2a; }
[data-testid="stHeader"] { background: #141414; }
/* 기본 텍스트 */
html, body, [class*="css"] { color: #e8e8e8; }
/* 사이드바 타이틀 */
.sidebar-title {
font-size: 0.7rem;
font-weight: 600;
letter-spacing: 0.12em;
color: #ccc;
text-transform: uppercase;
margin-bottom: 1rem;
padding-bottom: 0.5rem;
border-bottom: 1px solid #333;
}
/* 채널 카드 — 좌측 컬러 보더 */
.ch-card {
background: #1e1e1e;
border-left: 3px solid var(--ch-color);
border-radius: 0 8px 8px 0;
padding: 1rem 1.2rem;
margin-bottom: 0;
}
.ch-name {
font-size: 0.75rem;
font-weight: 600;
letter-spacing: 0.05em;
color: var(--ch-color);
text-transform: uppercase;
margin-bottom: 0.4rem;
}
.ch-value {
font-size: 1.9rem;
font-weight: 700;
color: #ffffff;
letter-spacing: -0.02em;
line-height: 1;
}
.ch-unit { font-size: 1rem; color: #bbb; margin-left: 2px; }
.ch-ci {
font-size: 0.72rem;
color: #999;
margin-top: 0.5rem;
line-height: 1.6;
}
.ch-ci span { color: #bbb; }
/* 섹션 헤더 */
.section-label {
font-size: 0.7rem;
font-weight: 600;
letter-spacing: 0.1em;
color: #ccc;
text-transform: uppercase;
margin-bottom: 1rem;
}
/* 탭 */
[data-testid="stTabs"] button {
font-size: 0.8rem;
letter-spacing: 0.04em;
color: #888 !important;
border-bottom: 2px solid transparent !important;
}
[data-testid="stTabs"] button[aria-selected="true"] {
color: #e8e8e8 !important;
border-bottom: 2px solid #e8e8e8 !important;
}
/* 버튼 */
[data-testid="stButton"] > button[kind="primary"] {
background: #e8e8e8;
color: #141414;
border: none;
border-radius: 6px;
font-weight: 600;
font-size: 0.82rem;
letter-spacing: 0.03em;
padding: 0.5rem 1.2rem;
}
[data-testid="stButton"] > button[kind="primary"]:hover {
background: #ffffff;
}
/* 인풋 */
[data-testid="stTextInput"] input,
[data-testid="stSelectbox"] > div {
background: #1e1e1e !important;
border: 1px solid #2a2a2a !important;
color: #e8e8e8 !important;
border-radius: 6px !important;
font-size: 0.82rem !important;
}
/* 인풋 레이블 */
[data-testid="stTextInput"] label,
[data-testid="stSelectbox"] label {
color: #ccc !important;
font-size: 0.78rem !important;
}
/* 힌트 텍스트 */
.help-text {
font-size: 0.7rem;
color: #888;
line-height: 1.6;
margin-top: 0.3rem;
}
/* 구분선 */
hr { border-color: #2a2a2a; }
/* 데이터프레임 */
[data-testid="stDataFrame"] { background: #1a1a1a; }
/* 분석 로그 (st.status) 텍스트 밝기 */
[data-testid="stStatusWidget"] p,
[data-testid="stStatusWidget"] span,
[data-testid="stStatusWidget"] div,
[data-testid="stExpander"] [data-testid="stMarkdownContainer"] p,
[data-testid="stExpander"] [data-testid="stMarkdownContainer"] span {
color: #e8e8e8 !important;
font-size: 0.82rem !important;
line-height: 1.8 !important;
}
/* ── 모바일 반응형 ─────────────────────────────────── */
@media (max-width: 768px) {
/* 채널 카드: 4열 → 2열 */
[data-testid="column"] {
width: 50% !important;
flex: 0 0 50% !important;
min-width: 50% !important;
}
/* 카드 내 값 폰트 축소 */
.ch-value { font-size: 1.5rem !important; }
.ch-ci { font-size: 0.75rem !important; }
.ch-name { font-size: 0.72rem !important; }
/* 탭 버튼 크기 확대 */
[data-testid="stTabs"] button {
font-size: 0.85rem !important;
padding: 0.5rem 0.8rem !important;
}
/* 여백 축소 */
[data-testid="stAppViewContainer"] > section { padding: 0.5rem !important; }
}
@media (max-width: 480px) {
/* 채널 카드: 2열 → 1열 */
[data-testid="column"] {
width: 100% !important;
flex: 0 0 100% !important;
min-width: 100% !important;
}
.ch-value { font-size: 1.8rem !important; }
.ch-ci { font-size: 0.8rem !important; }
}
</style>
""", unsafe_allow_html=True)
# Single source of truth for the four channels: (key, line colour, display label).
# The tuple order is the display order used for cards, charts and exports.
_CHANNELS = (
    ("News_A", "#3B82F6", "뉴스A"),
    ("JTBC", "#A855F7", "JTBC뉴스룸"),
    ("MBN", "#F59E0B", "MBN뉴스7"),
    ("TVCHOSUN", "#EF4444", "TV조선뉴스9"),
)

# Channel key -> hex colour used for that channel's traces and card accent.
COLORS = {key: colour for key, colour, _ in _CHANNELS}

# Channel key -> human-readable programme name.
LABELS = {key: label for key, _, label in _CHANNELS}

# Channel keys in display order.
ORDER = [key for key, _, _ in _CHANNELS]
def hex_to_rgba(hex_color, alpha):
    """Convert a hex colour into a CSS ``rgba(r,g,b,a)`` string.

    Args:
        hex_color: Colour as ``"#RRGGBB"``. The leading ``#`` is optional,
            the 3-digit shorthand ``"#RGB"`` is expanded, and for an
            8-digit ``"#RRGGBBAA"`` value only the RGB part is used.
        alpha: Opacity; interpolated verbatim into the result.

    Returns:
        A string such as ``"rgba(59,130,246,0.08)"``.

    Raises:
        ValueError: If ``hex_color`` is not a valid hex colour.
    """
    digits = hex_color.strip()
    if digits.startswith("#"):
        digits = digits[1:]
    if len(digits) == 3:
        # Expand CSS shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    # int(..., 16) tolerates signs and whitespace, so validate explicitly.
    is_hex = all(ch in "0123456789abcdefABCDEF" for ch in digits)
    if len(digits) not in (6, 8) or not is_hex:
        raise ValueError(f"invalid hex colour: {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"
# Shared Plotly layout options; unpacked into every figure's update_layout()
# call so all charts use the same dark palette, fonts, grid and margins.
CHART_LAYOUT = {
    "template": "plotly_dark",
    "paper_bgcolor": "#141414",
    "plot_bgcolor": "#141414",
    "font": {
        "family": "Inter, -apple-system, sans-serif",
        "color": "#bbb",
        "size": 11,
    },
    "xaxis": {"gridcolor": "#222", "zeroline": False, "showline": False},
    "yaxis": {"gridcolor": "#222", "zeroline": False, "showline": False},
    "hovermode": "x unified",
    "hoverlabel": {
        "bgcolor": "#1e1e1e",
        "bordercolor": "#333",
        "font_color": "#e8e8e8",
    },
    "margin": {"l": 0, "r": 0, "t": 36, "b": 0},
}
# ── Model cache ──────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def get_forecaster(sheets_id, gid, predict_days, _log_fn=None):
    """Build a forecaster, load its data and run the forecast.

    The result is cached by Streamlit, so reruns with the same arguments
    reuse the already-computed forecaster instead of recomputing it.

    Args:
        sheets_id: Google Sheets document ID passed to the forecaster.
        gid: Sheet GID passed to the forecaster.
        predict_days: Forecast horizon in days.
        _log_fn: Optional callable handed to the forecaster as ``log_fn``.
            The leading underscore tells Streamlit not to hash it into the
            cache key.

    Returns:
        The ``TimeFMForecaster`` instance after ``load_data()`` and
        ``run_forecast()`` have been called on it.
    """
    # Imported lazily so the module is only loaded on the first forecast.
    from timefm_forecaster import TimeFMForecaster as _Forecaster

    model = _Forecaster(sheets_id, gid, log_fn=_log_fn)
    model.load_data()
    model.run_forecast(predict_days=predict_days)
    return model
# ── Sidebar ──────────────────────────────────────────────────
# Static HTML fragments are named up front so the layout code below reads
# as a plain sequence of widgets.
_SIDEBAR_BRAND_HTML = (
    "<div style='margin-bottom:0.3rem;'>"
    "<span style='font-size:1rem; font-weight:700; color:#e8e8e8; letter-spacing:-0.01em;'>"
    "종편 4사 시청률 예측 모델</span>"
    "</div>"
    "<div style='font-size:0.7rem; color:#888; margin-bottom:1.5rem;'>"
    "Powered by <b style='color:#ccc;'>TimeFM</b> (by Google)</div>"
)
_SIDEBAR_GID_HINT_HTML = (
    "<div class='help-text'>GID <b style='color:#ccc;'>0</b> = 첫 번째 시트 (기본값)<br>"
    "다른 시트 사용 시 URL에서 <code style='color:#ccc;'>gid=숫자</code> 확인</div>"
)
_SIDEBAR_ABOUT_HTML = (
    "<span style='font-size:0.75rem; color:#bbb; line-height:2.0;'>"
    "Google Research가 개발한 <b style='color:#ccc;'>시계열 예측 파운데이션 모델</b>입니다.<br>"
    "수억 개의 시계열 데이터로 사전 학습된 <b style='color:#ccc;'>머신러닝(Transformer)</b> 기반 모델로, "
    "별도 학습 없이 새로운 데이터에 바로 적용 가능한 "
    "<b style='color:#ccc;'>제로샷(Zero-shot)</b> 방식으로 동작합니다.<br><br>"
    "200M 파라미터 · PyTorch<br><br>"
    "<a href='https://arxiv.org/abs/2310.10688' target='_blank' "
    "style='color:#888; text-decoration:none;'>📄 arXiv 논문</a><br>"
    "<a href='https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/' target='_blank' "
    "style='color:#888; text-decoration:none;'>📝 Google Research 블로그</a><br>"
    "<a href='https://huggingface.co/google/timesfm-1.0-200m-pytorch' target='_blank' "
    "style='color:#888; text-decoration:none;'>🤗 HuggingFace 모델</a><br>"
    "<a href='https://github.com/google-research/timesfm' target='_blank' "
    "style='color:#888; text-decoration:none;'>💻 GitHub</a>"
    "</span>"
)


def _sidebar_section(title):
    """Render a section heading using the ``sidebar-title`` CSS class."""
    st.markdown(f'<div class="sidebar-title">{title}</div>', unsafe_allow_html=True)


def _sidebar_spacer():
    """Insert a one-line vertical gap."""
    st.markdown("<br>", unsafe_allow_html=True)


with st.sidebar:
    st.markdown(_SIDEBAR_BRAND_HTML, unsafe_allow_html=True)

    # Data source: which Google Sheet (and which tab of it) to read.
    _sidebar_section("Data Source")
    sheets_id = st.text_input(
        "Google Sheets ID",
        value="1uv9gNT9TDEu2qtPPOnQlhiznnb4lxmogwQFWmQbclIc",
        label_visibility="visible",
    )
    gid = st.text_input(
        "Sheet GID",
        value="0",
        help="GID는 시트 탭의 고유 번호입니다. 첫 번째 시트는 0, 이후 시트는 URL의 gid= 값을 확인하세요.",
    )
    st.markdown(_SIDEBAR_GID_HINT_HTML, unsafe_allow_html=True)
    _sidebar_spacer()

    # Forecast horizon; the last option (180 days) is preselected.
    _sidebar_section("Forecast")
    predict_days = st.selectbox(
        "예측 기간",
        [30, 60, 90, 180],
        index=3,
        format_func=lambda x: f"{x}일",
    )
    _sidebar_spacer()
    run = st.button("Run Forecast", use_container_width=True, type="primary")
    _sidebar_spacer()

    # Static blurb about the model, with reference links.
    _sidebar_section("About TimeFM")
    st.markdown(_SIDEBAR_ABOUT_HTML, unsafe_allow_html=True)
# ── Run forecast ─────────────────────────────────────────────
# On "Run Forecast", persist the chosen inputs in session state so later
# reruns (tab or widget interaction) keep showing the same forecast, and
# drop the cached forecaster so the sheet is re-read.
if run:
    # Pasted IDs often carry stray whitespace; trim before storing.
    _clean_sheets_id = sheets_id.strip()
    if _clean_sheets_id:
        st.session_state["sheets_id"] = _clean_sheets_id
        st.session_state["gid"] = gid.strip()
        st.session_state["predict_days"] = predict_days
        get_forecaster.clear()
    else:
        # Previously an empty ID made the button a silent no-op.
        st.sidebar.warning("Google Sheets ID를 입력하세요.")

# Nothing has been run in this session yet: show the landing placeholder
# and end the script run here.
if "sheets_id" not in st.session_state:
    st.markdown(
        "<div style='height:60vh; display:flex; align-items:center; justify-content:center;"
        " flex-direction:column; gap:0.5rem;'>"
        "<div style='font-size:0.75rem; letter-spacing:0.1em; color:#555; text-transform:uppercase;'>"
        "종편 4사 시청률 예측 모델</div>"
        "<div style='font-size:1.5rem; color:#444;'>← Run Forecast to begin</div>"
        "</div>",
        unsafe_allow_html=True,
    )
    st.stop()
# Messages emitted by the forecaster are collected here and replayed into
# the status widget once the forecaster call returns.
_log_buffer = []
_forecast_error = None


def _collect_log(msg):
    """Append one forecaster log message to the replay buffer."""
    _log_buffer.append(msg)


with st.status("TimeFM 분석 중...", expanded=True) as _status:
    try:
        forecaster = get_forecaster(
            st.session_state["sheets_id"],
            st.session_state["gid"],
            st.session_state["predict_days"],
            _log_fn=_collect_log,
        )
        for _line in _log_buffer:
            _status.write(_line)
        _status.update(label="분석 완료 ✓", state="complete", expanded=False)
    except Exception:
        # Top-level boundary: capture the traceback so it can be shown
        # below, outside the status widget.
        import traceback

        _forecast_error = traceback.format_exc()
        _status.update(label="오류 발생", state="error")

if _forecast_error is not None:
    st.error(f"**오류 상세**\n```\n{_forecast_error}\n```")
    st.stop()
# The headline forecast date is the day after the last date in the data.
_last_data_day = pd.to_datetime(forecaster.df["날짜"].max()).normalize()
target_dt = _last_data_day + pd.Timedelta(days=1)
preds = forecaster.get_today_predictions(target_dt)
predict_days = forecaster.predict_days

# ── Main title ───────────────────────────────────────────────
_TITLE_HTML = (
    "<div style='margin-bottom:0.3rem;'>"
    "<span style='font-size:1.6rem; font-weight:700; color:#ffffff; letter-spacing:-0.02em;'>"
    "종편 4사 메인뉴스 시청률 예보</span>"
    "</div>"
    "<div style='font-size:0.78rem; color:#888; margin-bottom:1.8rem;'>"
    "Powered by <b style='color:#bbb;'>Google TimeFM</b></div>"
)
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# ── Date header ──────────────────────────────────────────────
_forecast_day_label = target_dt.strftime("%Y.%m.%d")
_date_header_html = (
    "<div style='font-size:0.7rem; letter-spacing:0.1em; color:#999;"
    " text-transform:uppercase; margin-bottom:1rem;'>"
    f"Forecast · {_forecast_day_label}</div>"
)
st.markdown(_date_header_html, unsafe_allow_html=True)
# ── Channel cards ────────────────────────────────────────────
def _channel_card_html(label, color, pred):
    """Return the HTML for one channel's forecast card.

    Args:
        label: Display name shown at the top of the card.
        color: CSS colour exposed to the stylesheet as ``--ch-color``.
        pred: Mapping with ``forecast``, ``lower_90``, ``upper_90``,
            ``lower_95`` and ``upper_95`` numeric entries.

    Returns:
        The card markup as a single string.
    """
    # The interval bounds are joined with an en dash; they used to be
    # concatenated directly, which rendered as one unreadable number.
    return (
        f'<div class="ch-card" style="--ch-color:{color};">\n'
        f'<div class="ch-name">{label}</div>\n'
        f'<div class="ch-value">{pred["forecast"]:.3f}<span class="ch-unit">%</span></div>\n'
        '<div class="ch-ci">\n'
        f'<span>90%</span> {pred["lower_90"]:.3f} – {pred["upper_90"]:.3f}<br>\n'
        f'<span>95%</span> {pred["lower_95"]:.3f} – {pred["upper_95"]:.3f}\n'
        "</div>\n"
        "</div>"
    )


# One column per channel, in display order.
cols = st.columns(len(ORDER), gap="small")
for col, ch in zip(cols, ORDER):
    with col:
        st.markdown(
            _channel_card_html(LABELS[ch], COLORS[ch], preds[ch]),
            unsafe_allow_html=True,
        )
st.markdown("<br>", unsafe_allow_html=True)
# ── Tabs ─────────────────────────────────────────────────────
tab1, tab2, tab3 = st.tabs(["TREND", "CHANNEL", "EXPORT"])


def _filter_by_day_type(frame, day_type):
    """Keep only weekday or weekend rows of ``frame`` (by its ``ds`` column).

    Any ``day_type`` other than ``"Weekday"`` / ``"Weekend"`` returns the
    frame unchanged.
    """
    if day_type == "Weekday":
        return frame[pd.to_datetime(frame["ds"]).dt.dayofweek < 5]
    if day_type == "Weekend":
        return frame[pd.to_datetime(frame["ds"]).dt.dayofweek >= 5]
    return frame


# ── TREND tab ────────────────────────────────────────────────
# All four channels on one chart: history (solid), forecast (dotted) and a
# shaded band between yhat_lower and yhat_upper.
with tab1:
    day_filter = st.radio(
        # A non-empty label avoids Streamlit's empty-label accessibility
        # warning; it stays hidden via label_visibility.
        "요일 필터",
        ["All", "Weekday", "Weekend"],
        horizontal=True,
        key="day_filter",
        label_visibility="collapsed",
    )
    fig = go.Figure()
    for ch in ORDER:
        # forecaster.channels maps a data column name to a channel key;
        # look up the column holding this channel's history.
        hist_col = next(
            col for col, name in forecaster.channels.items() if name == ch
        )
        hist_df = forecaster.df[["날짜", hist_col]].copy()
        hist_df.columns = ["ds", "y"]
        hist_df = _filter_by_day_type(hist_df, day_filter)
        fc = _filter_by_day_type(forecaster.forecasts[ch], day_filter)

        color = COLORS[ch]
        fig.add_trace(go.Scatter(
            x=hist_df["ds"], y=hist_df["y"],
            name=LABELS[ch],
            line=dict(color=color, width=1.5),
            legendgroup=ch,
        ))
        fig.add_trace(go.Scatter(
            x=fc["ds"], y=fc["yhat"],
            name=f"{LABELS[ch]} forecast",
            line=dict(color=color, width=2, dash="dot"),
            legendgroup=ch, showlegend=False,
        ))
        # Closed polygon: upper bound left-to-right, then lower bound
        # right-to-left, filled with a faint tint of the channel colour.
        fig.add_trace(go.Scatter(
            x=pd.concat([fc["ds"], fc["ds"][::-1]]),
            y=pd.concat([fc["yhat_upper"], fc["yhat_lower"][::-1]]),
            fill="toself",
            fillcolor=hex_to_rgba(color, 0.08),
            line=dict(color="rgba(0,0,0,0)"),
            legendgroup=ch, showlegend=False, hoverinfo="skip",
        ))
    # x is given as epoch milliseconds rather than a Timestamp.
    # NOTE(review): presumably a workaround for Plotly's handling of
    # annotated vlines on date axes — confirm before changing.
    fig.add_vline(
        x=target_dt.timestamp() * 1000,
        line_dash="dot", line_color="#444",
        annotation_text="forecast start",
        annotation_font_color="#888",
        annotation_font_size=10,
    )
    fig.update_layout(
        **CHART_LAYOUT, height=460, yaxis_title="시청률 (%)",
        legend=dict(orientation="h", y=-0.12, font=dict(color="#ccc")),
    )
    st.plotly_chart(fig, use_container_width=True)
# ── CHANNEL tab ──────────────────────────────────────────────
# Detail view for one channel: 95% and 90% interval bands, the historical
# series and the forecast, followed by a summary table for all channels.
with tab2:
    selected_ch = st.selectbox(
        # A non-empty label avoids Streamlit's empty-label accessibility
        # warning; it stays hidden via label_visibility.
        "채널 선택",
        ORDER,
        format_func=lambda x: LABELS[x],
        key="channel_select",
        label_visibility="collapsed",
    )
    fc = forecaster.forecasts[selected_ch]
    # forecaster.channels maps a data column name to a channel key.
    hist_col = next(
        col for col, name in forecaster.channels.items() if name == selected_ch
    )
    hist_df = forecaster.df[["날짜", hist_col]].copy()
    hist_df.columns = ["ds", "y"]
    color = COLORS[selected_ch]

    fig2 = go.Figure()
    # Bands are closed polygons (upper bound forward, lower bound reversed).
    # The wider 95% band is drawn first so the 90% band sits on top of it.
    band_x = pd.concat([fc["ds"], fc["ds"][::-1]])
    for upper_col, lower_col, opacity, band_name in (
        ("yhat_upper", "yhat_lower", 0.07, "95% CI"),
        ("yhat_upper_90", "yhat_lower_90", 0.13, "90% CI"),
    ):
        fig2.add_trace(go.Scatter(
            x=band_x,
            y=pd.concat([fc[upper_col], fc[lower_col][::-1]]),
            fill="toself", fillcolor=hex_to_rgba(color, opacity),
            line=dict(color="rgba(0,0,0,0)"), name=band_name, hoverinfo="skip",
        ))
    fig2.add_trace(go.Scatter(
        x=hist_df["ds"], y=hist_df["y"],
        name="actual", line=dict(color="#777", width=1.5),
    ))
    fig2.add_trace(go.Scatter(
        x=fc["ds"], y=fc["yhat"],
        name="forecast", line=dict(color=color, width=2, dash="dot"),
    ))
    # x is given as epoch milliseconds, matching the TREND chart's marker.
    fig2.add_vline(
        x=target_dt.timestamp() * 1000,
        line_dash="dot", line_color="#444",
    )
    fig2.update_layout(
        **CHART_LAYOUT, height=420,
        title=dict(text=LABELS[selected_ch], font=dict(color="#ccc", size=12)),
        yaxis_title="시청률 (%)",
        legend=dict(orientation="h", y=-0.14, font=dict(color="#ccc")),
    )
    st.plotly_chart(fig2, use_container_width=True)

    # Summary table: the target-date prediction for every channel.
    st.markdown('<div class="section-label">Summary</div>', unsafe_allow_html=True)
    rows = [
        {
            "채널": LABELS[ch],
            "예측값": f"{preds[ch]['forecast']:.3f}%",
            "90% 하한": f"{preds[ch]['lower_90']:.3f}%",
            "90% 상한": f"{preds[ch]['upper_90']:.3f}%",
            "95% 하한": f"{preds[ch]['lower_95']:.3f}%",
            "95% 상한": f"{preds[ch]['upper_95']:.3f}%",
        }
        for ch in ORDER
    ]
    st.dataframe(pd.DataFrame(rows), hide_index=True, use_container_width=True)
# ── EXPORT tab ───────────────────────────────────────────────
# Wide table (one row per date, five columns per channel) with CSV
# downloads for the full horizon and for the target date alone.
with tab3:
    st.markdown(
        '<div class="section-label">4개 채널 날짜별 예측 데이터</div>',
        unsafe_allow_html=True,
    )

    forecast_cols = [
        "ds", "yhat", "yhat_lower_90", "yhat_upper_90", "yhat_lower", "yhat_upper",
    ]
    # Column-name suffixes, in the same order as forecast_cols[1:].
    suffixes = ["예측", "90하한", "90상한", "95하한", "95상한"]

    wide_parts = []
    for ch in ORDER:
        fc_ch = forecaster.forecasts[ch][forecast_cols].copy()
        fc_ch.columns = ["날짜"] + [f"{LABELS[ch]}_{suffix}" for suffix in suffixes]
        wide_parts.append(fc_ch.set_index("날짜"))

    # Align the channels on date, then present dates as sorted ISO strings.
    wide_df = pd.concat(wide_parts, axis=1).reset_index()
    wide_df["날짜"] = pd.to_datetime(wide_df["날짜"]).dt.strftime("%Y-%m-%d")
    wide_df = wide_df.sort_values("날짜").reset_index(drop=True)

    # Preview only the first 60 rows; the download carries everything.
    st.dataframe(wide_df.head(60), hide_index=True, use_container_width=True)

    target_day = target_dt.strftime("%Y-%m-%d")
    target_stamp = target_dt.strftime("%Y%m%d")

    c1, c2 = st.columns(2)
    with c1:
        st.download_button(
            "전체 예측기간 다운로드",
            # to_csv() without a path returns a str and ignores `encoding`,
            # so encode explicitly to emit the UTF-8 BOM that Excel needs to
            # read the Korean headers correctly.
            data=wide_df.to_csv(index=False).encode("utf-8-sig"),
            file_name=f"timefm_4ch_{target_stamp}_full.csv",
            mime="text/csv",
            use_container_width=True,
        )
    with c2:
        today_row = wide_df[wide_df["날짜"] == target_day]
        # The single-day button is offered only if the target date is present.
        if not today_row.empty:
            st.download_button(
                f"{target_day} 하루치 다운로드",
                data=today_row.to_csv(index=False).encode("utf-8-sig"),
                file_name=f"timefm_4ch_{target_stamp}.csv",
                mime="text/csv",
                use_container_width=True,
            )