Spaces:
Running
Running
| """ | |
| 종편 4사 시청률 예측 모델 — Google TimeFM 버전 | |
| """ | |
| import streamlit as st | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
# Browser-tab title/icon and the wide layout with the sidebar open by
# default; configured before any other element is rendered.
_PAGE_CONFIG = {
    "page_title": "종편 4사 시청률 예측 · TimeFM",
    "page_icon": "▲",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# ── CSS ──────────────────────────────────────────────────────
# Global dark theme and component styling, injected once as raw HTML.
#
# Streamlit exposes element hooks through `data-testid` attributes, which are
# not a stable public API. Newer Streamlit releases render columns as
# [data-testid="stColumn"] rather than [data-testid="column"], so the
# responsive column rules below list both selectors. Otherwise the 2-column /
# 1-column mobile layouts silently stop applying after an upgrade.
#
# NOTE(review): [data-testid="stStatusWidget"] appears to be the toolbar
# "running" indicator rather than the st.status container (which renders as
# an expander and is already covered by the stExpander selectors). Confirm
# before relying on it.
_APP_CSS = """
<style>
/* 전체 배경 */
[data-testid="stAppViewContainer"] { background: #141414; }
[data-testid="stSidebar"] { background: #1a1a1a; border-right: 1px solid #2a2a2a; }
[data-testid="stHeader"] { background: #141414; }
/* 기본 텍스트 */
html, body, [class*="css"] { color: #e8e8e8; }
/* 사이드바 타이틀 */
.sidebar-title {
font-size: 0.7rem;
font-weight: 600;
letter-spacing: 0.12em;
color: #ccc;
text-transform: uppercase;
margin-bottom: 1rem;
padding-bottom: 0.5rem;
border-bottom: 1px solid #333;
}
/* 채널 카드 — 좌측 컬러 보더 */
.ch-card {
background: #1e1e1e;
border-left: 3px solid var(--ch-color);
border-radius: 0 8px 8px 0;
padding: 1rem 1.2rem;
margin-bottom: 0;
}
.ch-name {
font-size: 0.75rem;
font-weight: 600;
letter-spacing: 0.05em;
color: var(--ch-color);
text-transform: uppercase;
margin-bottom: 0.4rem;
}
.ch-value {
font-size: 1.9rem;
font-weight: 700;
color: #ffffff;
letter-spacing: -0.02em;
line-height: 1;
}
.ch-unit { font-size: 1rem; color: #bbb; margin-left: 2px; }
.ch-ci {
font-size: 0.72rem;
color: #999;
margin-top: 0.5rem;
line-height: 1.6;
}
.ch-ci span { color: #bbb; }
/* 섹션 헤더 */
.section-label {
font-size: 0.7rem;
font-weight: 600;
letter-spacing: 0.1em;
color: #ccc;
text-transform: uppercase;
margin-bottom: 1rem;
}
/* 탭 */
[data-testid="stTabs"] button {
font-size: 0.8rem;
letter-spacing: 0.04em;
color: #888 !important;
border-bottom: 2px solid transparent !important;
}
[data-testid="stTabs"] button[aria-selected="true"] {
color: #e8e8e8 !important;
border-bottom: 2px solid #e8e8e8 !important;
}
/* 버튼 */
[data-testid="stButton"] > button[kind="primary"] {
background: #e8e8e8;
color: #141414;
border: none;
border-radius: 6px;
font-weight: 600;
font-size: 0.82rem;
letter-spacing: 0.03em;
padding: 0.5rem 1.2rem;
}
[data-testid="stButton"] > button[kind="primary"]:hover {
background: #ffffff;
}
/* 인풋 */
[data-testid="stTextInput"] input,
[data-testid="stSelectbox"] > div {
background: #1e1e1e !important;
border: 1px solid #2a2a2a !important;
color: #e8e8e8 !important;
border-radius: 6px !important;
font-size: 0.82rem !important;
}
/* 인풋 레이블 */
[data-testid="stTextInput"] label,
[data-testid="stSelectbox"] label {
color: #ccc !important;
font-size: 0.78rem !important;
}
/* 힌트 텍스트 */
.help-text {
font-size: 0.7rem;
color: #888;
line-height: 1.6;
margin-top: 0.3rem;
}
/* 구분선 */
hr { border-color: #2a2a2a; }
/* 데이터프레임 */
[data-testid="stDataFrame"] { background: #1a1a1a; }
/* 분석 로그 (st.status) 텍스트 밝기 */
[data-testid="stStatusWidget"] p,
[data-testid="stStatusWidget"] span,
[data-testid="stStatusWidget"] div,
[data-testid="stExpander"] [data-testid="stMarkdownContainer"] p,
[data-testid="stExpander"] [data-testid="stMarkdownContainer"] span {
color: #e8e8e8 !important;
font-size: 0.82rem !important;
line-height: 1.8 !important;
}
/* ── 모바일 반응형 ─────────────────────────────────── */
@media (max-width: 768px) {
/* 채널 카드: 4열 → 2열 */
[data-testid="column"],
[data-testid="stColumn"] {
width: 50% !important;
flex: 0 0 50% !important;
min-width: 50% !important;
}
/* 카드 내 값 폰트 축소 */
.ch-value { font-size: 1.5rem !important; }
.ch-ci { font-size: 0.75rem !important; }
.ch-name { font-size: 0.72rem !important; }
/* 탭 버튼 크기 확대 */
[data-testid="stTabs"] button {
font-size: 0.85rem !important;
padding: 0.5rem 0.8rem !important;
}
/* 여백 축소 */
[data-testid="stAppViewContainer"] > section { padding: 0.5rem !important; }
}
@media (max-width: 480px) {
/* 채널 카드: 2열 → 1열 */
[data-testid="column"],
[data-testid="stColumn"] {
width: 100% !important;
flex: 0 0 100% !important;
min-width: 100% !important;
}
.ch-value { font-size: 1.8rem !important; }
.ch-ci { font-size: 0.8rem !important; }
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Channel registry: (internal key, chart colour, display label), listed in the
# order channels appear on the cards, charts and exports. COLORS, LABELS and
# ORDER are all derived from this one table so they cannot drift apart.
_CHANNEL_TABLE = (
    ("News_A", "#3B82F6", "뉴스A"),
    ("JTBC", "#A855F7", "JTBC뉴스룸"),
    ("MBN", "#F59E0B", "MBN뉴스7"),
    ("TVCHOSUN", "#EF4444", "TV조선뉴스9"),
)
COLORS = {key: hex_code for key, hex_code, _ in _CHANNEL_TABLE}
LABELS = {key: label for key, _, label in _CHANNEL_TABLE}
ORDER = [key for key, _, _ in _CHANNEL_TABLE]
def hex_to_rgba(hex_color, alpha):
    """Convert a CSS hex colour into an ``rgba(r,g,b,alpha)`` string.

    Accepts ``#RRGGBB`` (the form used in COLORS), the ``#RGB`` shorthand,
    and ``#RRGGBBAA`` (its alpha byte is ignored in favour of *alpha*).
    The leading ``#`` is optional. Previously a colour without ``#`` was
    silently mis-parsed and ``#RGB`` failed with an opaque error.

    Args:
        hex_color: colour string such as ``"#3B82F6"``.
        alpha: opacity, written verbatim as the fourth component.

    Returns:
        A string like ``"rgba(59,130,246,0.08)"`` (no spaces).

    Raises:
        ValueError: if *hex_color* is not a 3-, 6- or 8-digit hex colour.
    """
    digits = hex_color.strip().lstrip("#")
    if len(digits) == 3:
        # CSS shorthand: "#abc" means "#aabbcc".
        digits = "".join(d * 2 for d in digits)
    if len(digits) not in (6, 8) or not set(digits) <= set("0123456789abcdefABCDEF"):
        raise ValueError(f"not a hex colour: {hex_color!r}")
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"
# Shared Plotly layout for the dark theme. Charts splat it into
# update_layout() and add their own height / title / legend on top.
_GRID_AXIS = {"gridcolor": "#222", "zeroline": False, "showline": False}

CHART_LAYOUT = {
    "template": "plotly_dark",
    "paper_bgcolor": "#141414",
    "plot_bgcolor": "#141414",
    "font": {"family": "Inter, -apple-system, sans-serif", "color": "#bbb", "size": 11},
    # Independent copies so the x and y axis settings are separate dicts.
    "xaxis": dict(_GRID_AXIS),
    "yaxis": dict(_GRID_AXIS),
    "hovermode": "x unified",
    "hoverlabel": {"bgcolor": "#1e1e1e", "bordercolor": "#333", "font_color": "#e8e8e8"},
    "margin": {"l": 0, "r": 0, "t": 36, "b": 0},
}
# ── Model cache ──────────────────────────────────────────────
# cache_resource (not cache_data): the forecaster is returned as-is, without
# pickling, and is shared across reruns/sessions, so treat it as read-only.
# max_entries bounds memory because the cache key includes a free-form,
# user-supplied sheet ID and each entry keeps a whole forecaster alive.
@st.cache_resource(show_spinner=False, max_entries=4)
def get_forecaster(sheets_id, gid, predict_days, _log_fn=None):
    """Load the sheet, run the TimeFM forecast and return the forecaster.

    Results are cached per (sheets_id, gid, predict_days), so widget
    interactions (tab switches, filters) rerun the script without re-running
    the model. The "Run Forecast" handler calls ``get_forecaster.clear()``,
    which only exists on a Streamlit-cached function. Without this decorator
    that call raised AttributeError and nothing was cached.

    Args:
        sheets_id: Google Sheets document ID passed to TimeFMForecaster.
        gid: sheet-tab GID passed to TimeFMForecaster.
        predict_days: forecast horizon passed to ``run_forecast``.
        _log_fn: optional progress callback forwarded to the forecaster.
            The leading underscore tells Streamlit not to hash it, so a fresh
            callable on each rerun does not defeat the cache. On a cache hit
            the body does not run and the callback receives nothing.

    Returns:
        The TimeFMForecaster after ``load_data()`` and ``run_forecast()``.
    """
    # Imported lazily, presumably so the page renders before the forecasting
    # dependencies are loaded.
    from timefm_forecaster import TimeFMForecaster

    fc = TimeFMForecaster(sheets_id, gid, log_fn=_log_fn)
    fc.load_data()
    fc.run_forecast(predict_days=predict_days)
    return fc
# ── Sidebar ──────────────────────────────────────────────────
# Static, hard-coded HTML fragments for the sidebar, assembled up front so the
# widget code below reads as a plain sequence of controls.
# (The upstream Google model is published as "TimesFM", per the links below;
# this app labels it "TimeFM" throughout.)
_SIDEBAR_HEADER_HTML = (
    "<div style='margin-bottom:0.3rem;'>"
    "<span style='font-size:1rem; font-weight:700; color:#e8e8e8; letter-spacing:-0.01em;'>"
    "종편 4사 시청률 예측 모델</span>"
    "</div>"
    "<div style='font-size:0.7rem; color:#888; margin-bottom:1.5rem;'>"
    "Powered by <b style='color:#ccc;'>TimeFM</b> (by Google)</div>"
)
_GID_HELP_HTML = (
    "<div class='help-text'>GID <b style='color:#ccc;'>0</b> = 첫 번째 시트 (기본값)<br>"
    "다른 시트 사용 시 URL에서 <code style='color:#ccc;'>gid=숫자</code> 확인</div>"
)
_ABOUT_LINKS = (
    ("https://arxiv.org/abs/2310.10688", "📄 arXiv 논문"),
    ("https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/", "📝 Google Research 블로그"),
    ("https://huggingface.co/google/timesfm-1.0-200m-pytorch", "🤗 HuggingFace 모델"),
    ("https://github.com/google-research/timesfm", "💻 GitHub"),
)
_ABOUT_HTML = (
    "<span style='font-size:0.75rem; color:#bbb; line-height:2.0;'>"
    "Google Research가 개발한 <b style='color:#ccc;'>시계열 예측 파운데이션 모델</b>입니다.<br>"
    "수억 개의 시계열 데이터로 사전 학습된 <b style='color:#ccc;'>머신러닝(Transformer)</b> 기반 모델로, "
    "별도 학습 없이 새로운 데이터에 바로 적용 가능한 "
    "<b style='color:#ccc;'>제로샷(Zero-shot)</b> 방식으로 동작합니다.<br><br>"
    "200M 파라미터 · PyTorch<br><br>"
    + "<br>".join(
        f"<a href='{href}' target='_blank' style='color:#888; text-decoration:none;'>{text}</a>"
        for href, text in _ABOUT_LINKS
    )
    + "</span>"
)


def _sidebar_html(markup):
    """Render trusted, hard-coded HTML in the active Streamlit container."""
    st.markdown(markup, unsafe_allow_html=True)


def _section_title(title):
    """Render a small uppercase section heading (styled by .sidebar-title)."""
    _sidebar_html(f'<div class="sidebar-title">{title}</div>')


with st.sidebar:
    _sidebar_html(_SIDEBAR_HEADER_HTML)

    # Data source: which spreadsheet / tab to forecast from.
    _section_title("Data Source")
    sheets_id = st.text_input(
        "Google Sheets ID",
        value="1uv9gNT9TDEu2qtPPOnQlhiznnb4lxmogwQFWmQbclIc",
        label_visibility="visible",
    )
    gid = st.text_input(
        "Sheet GID",
        value="0",
        help="GID는 시트 탭의 고유 번호입니다. 첫 번째 시트는 0, 이후 시트는 URL의 gid= 값을 확인하세요.",
    )
    _sidebar_html(_GID_HELP_HTML)
    _sidebar_html("<br>")

    # Forecast horizon and the trigger button.
    _section_title("Forecast")
    predict_days = st.selectbox(
        "예측 기간",
        [30, 60, 90, 180],
        index=3,
        format_func=lambda days: f"{days}일",
    )
    _sidebar_html("<br>")
    run = st.button("Run Forecast", use_container_width=True, type="primary")
    _sidebar_html("<br>")

    # Background on the model, with reference links.
    _section_title("About TimeFM")
    _sidebar_html(_ABOUT_HTML)
# ── Run forecast ─────────────────────────────────────────────
# IDs pasted from a browser URL often carry stray whitespace, which would make
# the sheet fetch fail and also create a separate cache entry.
_run_sheets_id = sheets_id.strip()
# The sidebar documents GID 0 as the first/default sheet; fall back to it
# when the field is cleared.
_run_gid = gid.strip() or "0"
if run:
    if _run_sheets_id:
        # Persist the parameters so later reruns (tab switches, filter
        # changes) keep showing this forecast without another click.
        st.session_state["sheets_id"] = _run_sheets_id
        st.session_state["gid"] = _run_gid
        st.session_state["predict_days"] = predict_days
        # An explicit Run always re-fetches the sheet and re-forecasts.
        get_forecaster.clear()
    else:
        # Previously an empty ID turned the click into a silent no-op.
        st.sidebar.warning("Google Sheets ID를 입력하세요.")
# Nothing has been run in this session yet: show a centred placeholder and
# halt this script run so none of the forecast UI below is built.
if "sheets_id" not in st.session_state:
    _landing_title = (
        "<div style='font-size:0.75rem; letter-spacing:0.1em; color:#555; text-transform:uppercase;'>"
        "종편 4사 시청률 예측 모델</div>"
    )
    _landing_hint = "<div style='font-size:1.5rem; color:#444;'>← Run Forecast to begin</div>"
    st.markdown(
        "<div style='height:60vh; display:flex; align-items:center; justify-content:center;"
        " flex-direction:column; gap:0.5rem;'>"
        + _landing_title
        + _landing_hint
        + "</div>",
        unsafe_allow_html=True,
    )
    st.stop()
# ── Forecast execution ───────────────────────────────────────
# Progress messages from the forecaster are collected in a list and replayed
# into the status box afterwards. (If get_forecaster is cached, a cache hit
# produces no messages.)
_log_buffer = []
_forecast_error = None
with st.status("TimeFM 분석 중...", expanded=True) as _status:
    try:
        forecaster = get_forecaster(
            st.session_state["sheets_id"],
            st.session_state["gid"],
            st.session_state["predict_days"],
            _log_fn=_log_buffer.append,
        )
    except Exception:
        # Top-level UI boundary: keep the full traceback for display below.
        import traceback

        _forecast_error = traceback.format_exc()
    # Replay the log on failure as well as on success. Previously the
    # messages were dropped when the forecast raised, which is exactly when
    # they help the most.
    for msg in _log_buffer:
        _status.write(msg)
    if _forecast_error is None:
        _status.update(label="분석 완료 ✓", state="complete", expanded=False)
    else:
        _status.update(label="오류 발생", state="error")
if _forecast_error:
    st.error(f"**오류 상세**\n```\n{_forecast_error}\n```")
    st.stop()
# Forecast target: the day after the last observed date in the sheet.
# Parse the whole column *before* taking max(). If "날짜" is stored as text,
# a lexical max() is only correct for zero-padded ISO dates (e.g.
# "2024.9.30" sorts after "2024.10.1").
target_dt = pd.to_datetime(forecaster.df["날짜"]).max().normalize() + pd.Timedelta(days=1)
preds = forecaster.get_today_predictions(target_dt)
# Use the horizon the forecaster was actually built with. It can differ from
# the sidebar selection until "Run Forecast" is clicked again.
predict_days = forecaster.predict_days
# ── Main title ───────────────────────────────────────────────
_title_html = "".join([
    "<div style='margin-bottom:0.3rem;'>",
    "<span style='font-size:1.6rem; font-weight:700; color:#ffffff; letter-spacing:-0.02em;'>",
    "종편 4사 메인뉴스 시청률 예보</span>",
    "</div>",
    "<div style='font-size:0.78rem; color:#888; margin-bottom:1.8rem;'>",
    "Powered by <b style='color:#bbb;'>Google TimeFM</b></div>",
])
st.markdown(_title_html, unsafe_allow_html=True)
# ── Date header ──────────────────────────────────────────────
# Shows the forecast target date as YYYY.MM.DD above the channel cards.
_forecast_date_label = target_dt.strftime("%Y.%m.%d")
st.markdown(
    "<div style='font-size:0.7rem; letter-spacing:0.1em; color:#999;"
    " text-transform:uppercase; margin-bottom:1rem;'>"
    f"Forecast · {_forecast_date_label}</div>",
    unsafe_allow_html=True,
)
# ── Channel cards ────────────────────────────────────────────
# One card per channel: point forecast for the target date plus the 90% and
# 95% interval bounds. The card colour is passed through the --ch-color CSS
# variable used by .ch-card / .ch-name.
_card_columns = st.columns(4, gap="small")
for _column, _channel in zip(_card_columns, ORDER):
    _pred = preds[_channel]
    _card_html = "\n".join([
        f'<div class="ch-card" style="--ch-color:{COLORS[_channel]};">',
        f'<div class="ch-name">{LABELS[_channel]}</div>',
        f'<div class="ch-value">{_pred["forecast"]:.3f}<span class="ch-unit">%</span></div>',
        '<div class="ch-ci">',
        f'<span>90%</span> {_pred["lower_90"]:.3f} – {_pred["upper_90"]:.3f}<br>',
        f'<span>95%</span> {_pred["lower_95"]:.3f} – {_pred["upper_95"]:.3f}',
        '</div>',
        '</div>',
    ])
    with _column:
        st.markdown(_card_html, unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
# ── Tabs ─────────────────────────────────────────────────────
# Three views over the same forecast: all channels, a single channel, export.
_TAB_TITLES = ("TREND", "CHANNEL", "EXPORT")
tab1, tab2, tab3 = st.tabs(list(_TAB_TITLES))
# ── TREND tab ────────────────────────────────────────────────
# All four channels on one chart: observed history (solid line), forecast
# (dotted line) and a shaded yhat_lower..yhat_upper band per channel,
# optionally restricted to weekdays or weekends.
with tab1:
    day_filter = st.radio(
        # Streamlit warns on empty widget labels (accessibility). Give it a
        # real label and hide it with label_visibility="collapsed" instead.
        "요일 필터",
        ["All", "Weekday", "Weekend"],
        horizontal=True,
        key="day_filter",
        label_visibility="collapsed",
    )
    fig = go.Figure()
    for ch in ORDER:
        fc = forecaster.forecasts[ch]
        # forecaster.channels maps source column -> channel key; find the
        # column that holds this channel's observed values.
        hist_col = next(
            (k for k, v in forecaster.channels.items() if v == ch), None
        )
        if hist_col is None:
            raise KeyError(f"no history column mapped to channel {ch!r}")
        hist_df = forecaster.df[["날짜", hist_col]].copy()
        hist_df.columns = ["ds", "y"]
        # pandas dayofweek: Monday=0 ... Sunday=6, so >= 5 is the weekend.
        if day_filter == "Weekday":
            fc = fc[pd.to_datetime(fc["ds"]).dt.dayofweek < 5]
            hist_df = hist_df[pd.to_datetime(hist_df["ds"]).dt.dayofweek < 5]
        elif day_filter == "Weekend":
            fc = fc[pd.to_datetime(fc["ds"]).dt.dayofweek >= 5]
            hist_df = hist_df[pd.to_datetime(hist_df["ds"]).dt.dayofweek >= 5]
        color = COLORS[ch]
        # legendgroup ties the forecast line and band to the history trace's
        # legend entry, so toggling a channel hides all three.
        fig.add_trace(go.Scatter(
            x=hist_df["ds"], y=hist_df["y"],
            name=LABELS[ch],
            line=dict(color=color, width=1.5),
            legendgroup=ch,
        ))
        fig.add_trace(go.Scatter(
            x=fc["ds"], y=fc["yhat"],
            name=f"{LABELS[ch]} forecast",
            line=dict(color=color, width=2, dash="dot"),
            legendgroup=ch, showlegend=False,
        ))
        # Closed polygon (upper bound forward, lower bound reversed) so that
        # fill="toself" shades the interval between them.
        fig.add_trace(go.Scatter(
            x=pd.concat([fc["ds"], fc["ds"][::-1]]),
            y=pd.concat([fc["yhat_upper"], fc["yhat_lower"][::-1]]),
            fill="toself",
            fillcolor=hex_to_rgba(color, 0.08),
            line=dict(color="rgba(0,0,0,0)"),
            legendgroup=ch, showlegend=False, hoverinfo="skip",
        ))
    # Marker at the forecast target date; x is given in epoch milliseconds,
    # Plotly's numeric form for date-axis positions.
    fig.add_vline(
        x=target_dt.timestamp() * 1000,
        line_dash="dot", line_color="#444",
        annotation_text="forecast start",
        annotation_font_color="#888",
        annotation_font_size=10,
    )
    fig.update_layout(
        **CHART_LAYOUT, height=460, yaxis_title="시청률 (%)",
        legend=dict(orientation="h", y=-0.12, font=dict(color="#ccc")),
    )
    st.plotly_chart(fig, use_container_width=True)
# ── CHANNEL tab ──────────────────────────────────────────────
# One channel in detail: 95% and 90% interval bands, observed history and
# forecast, followed by a summary table of target-date predictions for all
# channels.
with tab2:
    selected_ch = st.selectbox(
        # Non-empty but hidden label; Streamlit warns on empty labels.
        "채널 선택",
        ORDER,
        format_func=lambda x: LABELS[x],
        key="channel_select",
        label_visibility="collapsed",
    )
    fc = forecaster.forecasts[selected_ch]
    # forecaster.channels maps source column -> channel key.
    hist_col = next(
        (k for k, v in forecaster.channels.items() if v == selected_ch), None
    )
    if hist_col is None:
        raise KeyError(f"no history column mapped to channel {selected_ch!r}")
    hist_df = forecaster.df[["날짜", hist_col]].copy()
    hist_df.columns = ["ds", "y"]
    color = COLORS[selected_ch]
    # x for the closed band polygons: dates forward, then reversed.
    band_x = pd.concat([fc["ds"], fc["ds"][::-1]])

    fig2 = go.Figure()
    # Wider (95%) band first so the narrower 90% band is drawn over it.
    fig2.add_trace(go.Scatter(
        x=band_x,
        y=pd.concat([fc["yhat_upper"], fc["yhat_lower"][::-1]]),
        fill="toself", fillcolor=hex_to_rgba(color, 0.07),
        line=dict(color="rgba(0,0,0,0)"), name="95% CI", hoverinfo="skip",
    ))
    fig2.add_trace(go.Scatter(
        x=band_x,
        y=pd.concat([fc["yhat_upper_90"], fc["yhat_lower_90"][::-1]]),
        fill="toself", fillcolor=hex_to_rgba(color, 0.13),
        line=dict(color="rgba(0,0,0,0)"), name="90% CI", hoverinfo="skip",
    ))
    fig2.add_trace(go.Scatter(
        x=hist_df["ds"], y=hist_df["y"],
        name="actual", line=dict(color="#777", width=1.5),
    ))
    fig2.add_trace(go.Scatter(
        x=fc["ds"], y=fc["yhat"],
        name="forecast", line=dict(color=color, width=2, dash="dot"),
    ))
    # Forecast target date, in epoch milliseconds on the date axis.
    fig2.add_vline(
        x=target_dt.timestamp() * 1000,
        line_dash="dot", line_color="#444",
    )
    fig2.update_layout(
        **CHART_LAYOUT, height=420,
        title=dict(text=LABELS[selected_ch], font=dict(color="#ccc", size=12)),
        yaxis_title="시청률 (%)",
        legend=dict(orientation="h", y=-0.14, font=dict(color="#ccc")),
    )
    st.plotly_chart(fig2, use_container_width=True)

    # Summary table: target-date prediction and interval bounds per channel.
    st.markdown('<div class="section-label">Summary</div>', unsafe_allow_html=True)
    rows = []
    for ch in ORDER:
        p = preds[ch]
        rows.append({
            "채널": LABELS[ch],
            "예측값": f"{p['forecast']:.3f}%",
            "90% 하한": f"{p['lower_90']:.3f}%",
            "90% 상한": f"{p['upper_90']:.3f}%",
            "95% 하한": f"{p['lower_95']:.3f}%",
            "95% 상한": f"{p['upper_95']:.3f}%",
        })
    st.dataframe(pd.DataFrame(rows), hide_index=True, use_container_width=True)
# ── EXPORT tab ───────────────────────────────────────────────
# Wide table: one row per forecast date, six columns per channel (forecast,
# 90% lower/upper, 95% lower/upper), with full-period and single-day CSV
# downloads.


def _csv_bytes(frame):
    """Serialise *frame* to CSV bytes prefixed with a UTF-8 BOM.

    ``DataFrame.to_csv()`` without a path returns a ``str`` and does not
    apply ``encoding``, so the previous ``encoding="utf-8-sig"`` never
    produced a BOM. Spreadsheet apps such as Excel then mis-decode the
    Korean headers. Encoding explicitly adds exactly one BOM.
    """
    return frame.to_csv(index=False).encode("utf-8-sig")


with tab3:
    st.markdown('<div class="section-label">4개 채널 날짜별 예측 데이터</div>', unsafe_allow_html=True)
    wide_parts = []
    for ch in ORDER:
        label = LABELS[ch]
        fc_ch = forecaster.forecasts[ch][["ds", "yhat", "yhat_lower_90", "yhat_upper_90", "yhat_lower", "yhat_upper"]].copy()
        fc_ch.columns = [
            "날짜",
            f"{label}_예측",
            f"{label}_90하한",
            f"{label}_90상한",
            f"{label}_95하한",
            f"{label}_95상한",
        ]
        wide_parts.append(fc_ch.set_index("날짜"))
    # Align channels on the shared date index, then bring the date back out
    # as a regular column.
    wide_df = pd.concat(wide_parts, axis=1).reset_index()
    wide_df["날짜"] = pd.to_datetime(wide_df["날짜"]).dt.strftime("%Y-%m-%d")
    # Zero-padded YYYY-MM-DD strings sort chronologically.
    wide_df = wide_df.sort_values("날짜").reset_index(drop=True)
    # The preview shows the first 60 rows; the full download has every row.
    st.dataframe(wide_df.head(60), hide_index=True, use_container_width=True)

    target_label = target_dt.strftime("%Y-%m-%d")
    c1, c2 = st.columns(2)
    with c1:
        st.download_button(
            "전체 예측기간 다운로드",
            data=_csv_bytes(wide_df),
            file_name=f"timefm_4ch_{target_dt.strftime('%Y%m%d')}_full.csv",
            mime="text/csv",
            use_container_width=True,
        )
    with c2:
        today_row = wide_df[wide_df["날짜"] == target_label]
        # Only offered when the forecast actually contains the target date.
        if not today_row.empty:
            st.download_button(
                f"{target_label} 하루치 다운로드",
                data=_csv_bytes(today_row),
                file_name=f"timefm_4ch_{target_dt.strftime('%Y%m%d')}.csv",
                mime="text/csv",
                use_container_width=True,
            )