Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import plotly.graph_objects as go | |
| import joblib | |
| from transformers import ( | |
| TimesFmModelForPrediction, | |
| PatchTSTConfig, | |
| PatchTSTForPrediction | |
| ) | |
| from nixtla import NixtlaClient | |
| from sklearn.metrics import mean_absolute_error, root_mean_squared_error | |
| from datetime import datetime | |
| from statsmodels.tsa.statespace.sarimax import SARIMAX | |
| from vnstock import Vnstock | |
| import time | |
| from datetime import timedelta | |
| import warnings | |
| warnings.filterwarnings("ignore", message="To copy construct from a tensor") | |
| st.markdown(""" | |
| <style> | |
| /* Import fonts */ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap'); | |
| /* Root Variables */ | |
| :root { | |
| --primary: #00D9FF; | |
| --secondary: #A855F7; | |
| --accent: #FFD93D; | |
| --success: #10B981; | |
| --warning: #F59E0B; | |
| --error: #EF4444; | |
| --dark-bg: #0A0E1A; | |
| --dark-surface: #12182B; | |
| --dark-elevated: #1E293B; | |
| --glass: rgba(18, 24, 43, 0.6); | |
| --glass-border: rgba(255, 255, 255, 0.12); | |
| --text-primary: #F9FAFB; | |
| --text-secondary: #E5E7EB; | |
| --text-muted: #9CA3AF; | |
| } | |
| /* Global Styles */ | |
| .stApp { | |
| background: linear-gradient(135deg, var(--dark-bg) 0%, #12182B 50%, var(--dark-bg) 100%); | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; | |
| } | |
| /* Animated Background với nhiều layer */ | |
| .stApp::before { | |
| content: ""; | |
| position: fixed; | |
| inset: 0; | |
| background: | |
| radial-gradient(circle at 20% 30%, rgba(0, 217, 255, 0.12) 0%, transparent 50%), | |
| radial-gradient(circle at 80% 70%, rgba(168, 85, 247, 0.10) 0%, transparent 50%), | |
| radial-gradient(circle at 50% 50%, rgba(255, 217, 61, 0.05) 0%, transparent 60%); | |
| animation: bgPulse 20s ease-in-out infinite; | |
| pointer-events: none; | |
| z-index: 0; | |
| } | |
| @keyframes bgPulse { | |
| 0%, 100% { | |
| opacity: 1; | |
| transform: scale(1) rotate(0deg); | |
| } | |
| 25% { | |
| opacity: 0.8; | |
| transform: scale(1.05) rotate(1deg); | |
| } | |
| 50% { | |
| opacity: 0.7; | |
| transform: scale(1.1) rotate(0deg); | |
| } | |
| 75% { | |
| opacity: 0.8; | |
| transform: scale(1.05) rotate(-1deg); | |
| } | |
| } | |
| /* Grid Pattern với animation */ | |
| .stApp::after { | |
| content: ""; | |
| position: fixed; | |
| inset: 0; | |
| background-image: | |
| linear-gradient(rgba(0, 217, 255, 0.03) 1px, transparent 1px), | |
| linear-gradient(90deg, rgba(0, 217, 255, 0.03) 1px, transparent 1px); | |
| background-size: 50px 50px; | |
| pointer-events: none; | |
| z-index: 1; | |
| animation: gridMove 30s linear infinite; | |
| } | |
| @keyframes gridMove { | |
| 0% { | |
| transform: translate(0, 0); | |
| opacity: 1; | |
| } | |
| 50% { | |
| opacity: 0.5; | |
| } | |
| 100% { | |
| transform: translate(50px, 50px); | |
| opacity: 1; | |
| } | |
| } | |
| /* Container */ | |
| .block-container { | |
| position: relative; | |
| z-index: 2; | |
| padding: 2rem 2.5rem; | |
| max-width: 1400px; | |
| } | |
| /* Typography */ | |
| h1 { | |
| font-weight: 800 !important; | |
| font-size: 3rem !important; | |
| background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%); | |
| -webkit-background-clip: text; | |
| -webkit-text-fill-color: transparent; | |
| background-clip: text; | |
| margin-bottom: 1.5rem !important; | |
| letter-spacing: -0.02em; | |
| text-align: center !important; | |
| animation: slideDown 0.8s cubic-bezier(0.16, 1, 0.3, 1); | |
| } | |
| h1 .emoji { | |
| -webkit-text-fill-color: currentColor !important; | |
| background: none !important; | |
| filter: drop-shadow(0 0 20px rgba(0, 217, 255, 0.5)); | |
| } | |
| h2 { | |
| color: #FFFFFF; | |
| font-weight: 700 !important; | |
| font-size: 1.75rem !important; | |
| margin-top: 3rem !important; | |
| margin-bottom: 1.5rem !important; | |
| padding-bottom: 0.75rem; | |
| border-bottom: 2px solid var(--dark-elevated); | |
| position: relative; | |
| animation: slideUp 0.6s cubic-bezier(0.16, 1, 0.3, 1); | |
| } | |
| h2::before { | |
| content: ""; | |
| position: absolute; | |
| left: 0; | |
| bottom: -2px; | |
| width: 80px; | |
| height: 2px; | |
| background: linear-gradient(90deg, var(--primary), var(--secondary)); | |
| border-radius: 2px; | |
| } | |
| h3 { | |
| color: var(--primary); | |
| font-weight: 600 !important; | |
| font-size: 1.25rem !important; | |
| margin-top: 2rem !important; | |
| margin-bottom: 1rem !important; | |
| } | |
| /* Paragraphs - TĂNG ĐỘ SÁNG */ | |
| p, li { | |
| color: #E5E7EB !important; | |
| line-height: 1.7; | |
| font-size: 1.05rem; | |
| font-weight: 400; | |
| } | |
| /* Chữ trong danh sách */ | |
| ul, ol { | |
| color: #E5E7EB !important; | |
| } | |
| ul li, ol li { | |
| color: #E5E7EB !important; | |
| margin-bottom: 0.5rem; | |
| } | |
| /* Text được highlight hoặc bold */ | |
| strong, b { | |
| color: #F9FAFB !important; | |
| font-weight: 600; | |
| } | |
| /* Sidebar */ | |
| [data-testid="stSidebar"] { | |
| background: linear-gradient(180deg, var(--dark-surface) 0%, var(--dark-bg) 100%); | |
| border-right: 1px solid var(--dark-elevated); | |
| } | |
| [data-testid="stSidebar"] [data-testid="stMarkdownContainer"] p { | |
| color: #E5E7EB !important; | |
| font-size: 0.95rem; | |
| } | |
| [data-testid="stSidebar"] .stRadio > label { | |
| color: #FFFFFF !important; | |
| font-weight: 600; | |
| font-size: 1rem; | |
| } | |
| /* Cards & Containers */ | |
| .hero-intro { | |
| background: var(--glass); | |
| backdrop-filter: blur(20px); | |
| border: 1px solid var(--glass-border); | |
| border-left: 3px solid var(--primary); | |
| border-radius: 16px; | |
| padding: 2rem; | |
| margin: 1.5rem 0; | |
| box-shadow: | |
| 0 8px 32px rgba(0, 0, 0, 0.3), | |
| inset 0 1px 0 rgba(255, 255, 255, 0.08); | |
| animation: slideUp 0.8s cubic-bezier(0.16, 1, 0.3, 1) 0.2s both; | |
| transition: all 0.3s ease; | |
| } | |
| .hero-intro:hover { | |
| border-left-color: var(--secondary); | |
| box-shadow: | |
| 0 12px 48px rgba(0, 0, 0, 0.4), | |
| 0 0 40px rgba(0, 217, 255, 0.15); | |
| } | |
| .hero-intro h3 { | |
| color: var(--primary); | |
| margin-top: 0 !important; | |
| margin-bottom: 1rem !important; | |
| } | |
| .hero-intro p { | |
| color: #F9FAFB !important; | |
| line-height: 1.7; | |
| font-size: 1.05rem; | |
| margin-bottom: 1rem; | |
| } | |
| .hero-intro ul { | |
| color: #E5E7EB !important; | |
| line-height: 1.8; | |
| margin-left: 1.5rem; | |
| } | |
| .hero-intro li { | |
| color: #E5E7EB !important; | |
| } | |
| .hero-intro strong { | |
| color: var(--secondary); | |
| font-weight: 600; | |
| } | |
| /* Date Display - CĂN LỀ TRÁI */ | |
| .date-display { | |
| background: linear-gradient(135deg, rgba(0, 217, 255, 0.08) 0%, rgba(168, 85, 247, 0.08) 100%); | |
| border: 1px solid var(--glass-border); | |
| border-radius: 12px; | |
| padding: 1rem 1.5rem; | |
| margin: 1.5rem 0; | |
| text-align: left; | |
| backdrop-filter: blur(10px); | |
| animation: fadeIn 0.8s ease; | |
| } | |
| .date-display p { | |
| color: #F9FAFB !important; | |
| font-size: 1.05rem; | |
| margin: 0; | |
| font-family: 'JetBrains Mono', monospace; | |
| } | |
| .date-display b { | |
| color: var(--primary); | |
| font-weight: 700; | |
| } | |
| /* Buttons */ | |
| .stButton > button { | |
| background: linear-gradient(135deg, #0099B8 0%, #7C3AAC 100%) !important; | |
| color: #FFFFFF !important; | |
| border: none !important; | |
| border-radius: 12px !important; | |
| padding: 0.85rem 2.5rem !important; | |
| font-weight: 700 !important; | |
| font-size: 1.05rem !important; | |
| letter-spacing: 0.05em; | |
| text-transform: uppercase; | |
| transition: all 0.4s cubic-bezier(0.16, 1, 0.3, 1) !important; | |
| box-shadow: | |
| 0 4px 20px rgba(0, 153, 184, 0.4), | |
| 0 0 40px rgba(124, 58, 172, 0.25) !important; | |
| position: relative; | |
| overflow: hidden; | |
| background-size: 200% 200% !important; | |
| animation: gradientShift 3s ease infinite; | |
| } | |
| /* Animated gradient background */ | |
| @keyframes gradientShift { | |
| 0% { | |
| background-position: 0% 50%; | |
| } | |
| 50% { | |
| background-position: 100% 50%; | |
| } | |
| 100% { | |
| background-position: 0% 50%; | |
| } | |
| } | |
| /* Ripple effect */ | |
| .stButton > button::before { | |
| content: ""; | |
| position: absolute; | |
| top: 50%; | |
| left: 50%; | |
| width: 0; | |
| height: 0; | |
| border-radius: 50%; | |
| background: rgba(255, 255, 255, 0.3); | |
| transform: translate(-50%, -50%); | |
| transition: width 0.6s, height 0.6s; | |
| } | |
| /* Shimmer effect */ | |
| .stButton > button::after { | |
| content: ""; | |
| position: absolute; | |
| top: -50%; | |
| left: -50%; | |
| width: 200%; | |
| height: 200%; | |
| background: linear-gradient( | |
| 45deg, | |
| transparent 30%, | |
| rgba(255, 255, 255, 0.2) 50%, | |
| transparent 70% | |
| ); | |
| animation: shimmer 3s infinite; | |
| } | |
| @keyframes shimmer { | |
| 0% { | |
| transform: translateX(-100%) translateY(-100%) rotate(45deg); | |
| } | |
| 100% { | |
| transform: translateX(100%) translateY(100%) rotate(45deg); | |
| } | |
| } | |
| /* Text styling */ | |
| .stButton > button span, | |
| .stButton > button p, | |
| .stButton > button div { | |
| color: #FFFFFF !important; | |
| position: relative; | |
| z-index: 2; | |
| font-weight: 700 !important; | |
| } | |
| .stButton > button:hover { | |
| transform: translateY(-4px) scale(1.03) !important; | |
| box-shadow: | |
| 0 10px 40px rgba(0, 153, 184, 0.6), | |
| 0 0 80px rgba(124, 58, 172, 0.4), | |
| inset 0 0 30px rgba(255, 255, 255, 0.15) !important; | |
| animation: gradientShift 1.5s ease infinite, buttonPulse 1s ease infinite; | |
| } | |
| @keyframes buttonPulse { | |
| 0%, 100% { | |
| box-shadow: | |
| 0 10px 40px rgba(0, 153, 184, 0.6), | |
| 0 0 80px rgba(124, 58, 172, 0.4); | |
| } | |
| 50% { | |
| box-shadow: | |
| 0 10px 50px rgba(0, 153, 184, 0.8), | |
| 0 0 100px rgba(124, 58, 172, 0.6); | |
| } | |
| } | |
| .stButton > button:hover::before { | |
| width: 400px; | |
| height: 400px; | |
| } | |
| .stButton > button:active { | |
| transform: translateY(-2px) scale(0.98) !important; | |
| } | |
| /* Selectbox & Input */ | |
| .stSelectbox > div > div, | |
| .stNumberInput > div > div > input { | |
| background: rgba(255, 255, 255, 0.04) !important; | |
| border: 1px solid var(--glass-border) !important; | |
| border-radius: 10px !important; | |
| color: #F9FAFB !important; | |
| font-family: 'Inter', sans-serif; | |
| transition: all 0.3s ease !important; | |
| } | |
| .stSelectbox > div > div:hover, | |
| .stNumberInput > div > div > input:hover { | |
| background: rgba(255, 255, 255, 0.06) !important; | |
| border-color: var(--primary) !important; | |
| } | |
| .stSelectbox > div > div:focus, | |
| .stNumberInput > div > div > input:focus { | |
| border-color: var(--primary) !important; | |
| box-shadow: 0 0 0 3px rgba(0, 217, 255, 0.15) !important; | |
| } | |
| /* Widget Labels */ | |
| div[data-testid="stWidgetLabel"] > label { | |
| font-size: 1rem !important; | |
| font-weight: 600 !important; | |
| color: #FFFFFF !important; | |
| margin-bottom: 0.5rem !important; | |
| letter-spacing: 0.02em; | |
| } | |
| /* Checkboxes */ | |
| .stCheckbox { | |
| background: var(--glass); | |
| border: 1px solid var(--glass-border); | |
| border-radius: 10px; | |
| padding: 0.85rem 1.2rem; | |
| margin-bottom: 0.75rem; | |
| transition: all 0.2s ease; | |
| backdrop-filter: blur(10px); | |
| } | |
| .stCheckbox:hover { | |
| border-color: var(--primary); | |
| background: rgba(0, 217, 255, 0.08); | |
| transform: translateX(4px); | |
| } | |
| .stCheckbox label { | |
| color: #F9FAFB !important; | |
| font-weight: 500; | |
| font-size: 1rem; | |
| } | |
| /* Radio Buttons */ | |
| .stRadio > div { | |
| background: var(--glass); | |
| border: 1px solid var(--glass-border); | |
| border-radius: 12px; | |
| padding: 1.2rem; | |
| backdrop-filter: blur(10px); | |
| } | |
| .stRadio label { | |
| color: #F9FAFB !important; | |
| font-weight: 500; | |
| } | |
| /* Expander */ | |
| .streamlit-expanderHeader { | |
| background: var(--glass) !important; | |
| backdrop-filter: blur(10px) !important; | |
| border: 1px solid var(--glass-border) !important; | |
| border-radius: 10px !important; | |
| color: var(--primary) !important; | |
| font-weight: 600 !important; | |
| transition: all 0.3s ease !important; | |
| } | |
| .streamlit-expanderHeader:hover { | |
| border-color: var(--primary) !important; | |
| background: rgba(0, 217, 255, 0.08) !important; | |
| transform: translateX(4px); | |
| } | |
| .streamlit-expanderContent { | |
| background: rgba(18, 24, 43, 0.4); | |
| border: 1px solid var(--glass-border); | |
| border-top: none; | |
| border-radius: 0 0 10px 10px; | |
| backdrop-filter: blur(10px); | |
| } | |
| /* Metrics */ | |
| [data-testid="stMetricValue"] { | |
| color: var(--primary); | |
| font-family: 'JetBrains Mono', monospace; | |
| font-size: 2rem; | |
| font-weight: 800; | |
| } | |
| [data-testid="stMetricLabel"] { | |
| color: #E5E7EB; | |
| font-size: 0.9rem; | |
| font-weight: 600; | |
| text-transform: uppercase; | |
| letter-spacing: 0.08em; | |
| } | |
| .stMetric { | |
| background: var(--glass) !important; | |
| backdrop-filter: blur(20px) !important; | |
| border: 1px solid var(--glass-border) !important; | |
| border-radius: 16px !important; | |
| padding: 1.5rem !important; | |
| transition: all 0.3s cubic-bezier(0.16, 1, 0.3, 1) !important; | |
| box-shadow: | |
| 0 4px 16px rgba(0, 0, 0, 0.2), | |
| inset 0 1px 0 rgba(255, 255, 255, 0.08) !important; | |
| } | |
| .stMetric:hover { | |
| transform: translateY(-4px) !important; | |
| border-color: var(--primary) !important; | |
| box-shadow: | |
| 0 8px 32px rgba(0, 0, 0, 0.3), | |
| 0 0 30px rgba(0, 217, 255, 0.2) !important; | |
| } | |
| /* DataFrames */ | |
| .stDataFrame { | |
| background: var(--glass) !important; | |
| backdrop-filter: blur(20px) !important; | |
| border: 1px solid var(--glass-border) !important; | |
| border-radius: 12px !important; | |
| overflow: hidden; | |
| box-shadow: | |
| 0 8px 32px rgba(0, 0, 0, 0.3), | |
| inset 0 1px 0 rgba(255, 255, 255, 0.08); | |
| } | |
| .stDataFrame thead tr th { | |
| background: linear-gradient(135deg, rgba(0, 217, 255, 0.12), rgba(168, 85, 247, 0.12)) !important; | |
| color: #FFFFFF !important; | |
| font-weight: 700 !important; | |
| font-size: 0.9rem !important; | |
| text-transform: uppercase; | |
| letter-spacing: 0.05em; | |
| padding: 1rem 0.8rem !important; | |
| border: none !important; | |
| } | |
| .stDataFrame tbody tr:hover { | |
| background: rgba(0, 217, 255, 0.06) !important; | |
| } | |
| .stDataFrame tbody tr td { | |
| color: #E5E7EB !important; | |
| font-size: 0.95rem !important; | |
| padding: 0.9rem 0.8rem !important; | |
| font-family: 'JetBrains Mono', monospace; | |
| border-color: rgba(255, 255, 255, 0.05) !important; | |
| } | |
| /* Progress Bar */ | |
| .stProgress > div > div { | |
| background: linear-gradient(90deg, var(--primary) 0%, var(--secondary) 100%) !important; | |
| border-radius: 10px; | |
| animation: progressPulse 2s ease infinite; | |
| } | |
| @keyframes progressPulse { | |
| 0%, 100% { | |
| opacity: 1; | |
| } | |
| 50% { | |
| opacity: 0.8; | |
| } | |
| } | |
| /* Status Messages */ | |
| .stSuccess { | |
| background: rgba(16, 185, 129, 0.1) !important; | |
| border-left: 4px solid var(--success) !important; | |
| color: var(--success) !important; | |
| border-radius: 8px; | |
| } | |
| .stWarning { | |
| background: rgba(245, 158, 11, 0.1) !important; | |
| border-left: 4px solid var(--warning) !important; | |
| color: var(--warning) !important; | |
| border-radius: 8px; | |
| } | |
| .stInfo { | |
| background: rgba(0, 217, 255, 0.1) !important; | |
| border-left: 4px solid var(--primary) !important; | |
| color: var(--primary) !important; | |
| border-radius: 8px; | |
| } | |
| .stError { | |
| background: rgba(239, 68, 68, 0.1) !important; | |
| border-left: 4px solid var(--error) !important; | |
| color: var(--error) !important; | |
| border-radius: 8px; | |
| } | |
| /* Plotly Charts */ | |
| .js-plotly-plot { | |
| border: 1px solid var(--glass-border); | |
| border-radius: 16px; | |
| overflow: hidden; | |
| box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3); | |
| } | |
| /* Divider */ | |
| hr { | |
| border-color: var(--dark-elevated); | |
| margin: 2.5rem 0; | |
| opacity: 0.5; | |
| } | |
| /* Scrollbar */ | |
| ::-webkit-scrollbar { | |
| width: 10px; | |
| height: 10px; | |
| } | |
| ::-webkit-scrollbar-track { | |
| background: var(--dark-surface); | |
| } | |
| ::-webkit-scrollbar-thumb { | |
| background: linear-gradient(180deg, var(--primary), var(--secondary)); | |
| border-radius: 10px; | |
| } | |
| ::-webkit-scrollbar-thumb:hover { | |
| background: linear-gradient(180deg, var(--secondary), var(--primary)); | |
| } | |
| /* Animations */ | |
| @keyframes slideDown { | |
| from { | |
| opacity: 0; | |
| transform: translateY(-20px); | |
| } | |
| to { | |
| opacity: 1; | |
| transform: translateY(0); | |
| } | |
| } | |
| @keyframes slideUp { | |
| from { | |
| opacity: 0; | |
| transform: translateY(20px); | |
| } | |
| to { | |
| opacity: 1; | |
| transform: translateY(0); | |
| } | |
| } | |
| @keyframes fadeIn { | |
| from { | |
| opacity: 0; | |
| } | |
| to { | |
| opacity: 1; | |
| } | |
| } | |
| /* Responsive */ | |
| @media (max-width: 768px) { | |
| .block-container { | |
| padding: 1.5rem; | |
| } | |
| h1 { | |
| font-size: 2rem !important; | |
| } | |
| h2 { | |
| font-size: 1.5rem !important; | |
| } | |
| .hero-intro { | |
| padding: 1.5rem; | |
| } | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # 1. Page Config | |
| st.set_page_config( | |
| page_title="VN30 Forecast - Model Comparison", | |
| layout="wide" | |
| ) | |
| # Sitebar navigation | |
| st.sidebar.title("📂 VN30 Forecast") | |
| st.sidebar.markdown("---") | |
| menu = st.sidebar.radio( | |
| "Menu", | |
| [ | |
| "📊 So sánh model", | |
| "🧪 Demo sản phẩm", | |
| "📘 Tài liệu" | |
| ] | |
| ) | |
| if menu == "📊 So sánh model": | |
| # 2. Load data | |
| def load_data(): | |
| df = pd.read_csv("VN30_Train.csv") | |
| df["time"] = pd.to_datetime(df["time"]) | |
| df = df.sort_values(["symbol", "time"]) | |
| return df | |
| df_train = load_data() | |
| symbols = sorted(df_train["symbol"].unique()) | |
| def load_test_data(): | |
| df = pd.read_csv("VN30_Test.csv") | |
| df["time"] = pd.to_datetime(df["time"]) | |
| df = df.sort_values(["symbol", "time"]) | |
| return df | |
| df_test = load_test_data() | |
| def load_news(): | |
| df = pd.read_csv("VN30_news_with_sentiment.csv") | |
| df["date"] = pd.to_datetime(df["date"]) | |
| return df | |
| df_news = load_news() | |
| # 3. UI | |
| st.markdown('<h1 style="text-align: center;"><span class="emoji">📈</span> VN30 Stock Forecast - Model Comparison</h1>', unsafe_allow_html=True) | |
| a1, a2, a3 = st.columns([1, 2, 1]) | |
| with a2: | |
| st.image( | |
| "VN30_Model_Thumb.png", | |
| width=1000 | |
| ) | |
| st.markdown("---") | |
| with st.expander("🎯 **Mục tiêu dự án** (bấm để xem chi tiết)"): | |
| st.markdown( | |
| """ | |
| <div class="hero-intro"> | |
| <h3 style="margin-top: 0;">🎯 Mục tiêu dự án</h3> | |
| <p>Dự án này nhằm xây dựng một <strong>ứng dụng dự báo giá cổ phiếu VN30 theo ngày</strong> bằng nhiều phương pháp forecasting khác nhau, bao gồm:</p> | |
| <ul style="margin: 1rem 0;"> | |
| <li><strong>Technical Analysis</strong>: Dự báo dựa trên xu hướng, động lượng và biến động giá thông qua các chỉ báo kỹ thuật.</li> | |
| <li><strong>SARIMA</strong>: Mô hình machine learning thống kê cổ điển cho chuỗi thời gian, kết hợp xu hướng, mùa vụ và biến ngoại sinh.</li> | |
| <li><strong>PatchTST</strong>: Mô hình deep learning - Transformer time-series dùng patch để tối ưu bộ nhớ.</li> | |
| <li><strong>TimeFM</strong>: Foundation model cho time series, model này chỉ nên ở mức tham khảo vì khả năng dự đoán không tốt.</li> | |
| <li><strong>TimeGPT</strong>: Generative forecasting model.</li> | |
| <li><strong>TimeGPT with news</strong>: TimeGPT kết hợp qwen3 xuất ra sentiment từ tin tức (beta), tin tức được lấy từ https://cafef.vn/</li> | |
| </ul> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # Xuất ra ngày hiện tại | |
| today_str = datetime.now().strftime("%d/%m/%Y") | |
| st.markdown( | |
| f""" | |
| <div class="date-display"> | |
| <p>📅 Ngày hiện tại: <b>{today_str}</b></p> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # 4. Crawl data | |
| def crawl_vn30_data(start_date, end_date, filename): | |
| symbols = [ | |
| "ACB","DGC","BCM","BID","FPT","HDB","HPG","LPB","MSN","MBB", | |
| "MWG","PLX","GAS","SAB","STB","SHB","SSB","SSI","TCB","TPB", | |
| "VCB","CTG","VJC","VIB","GVR","VNM","VRE","VIC","VHM","VPB" | |
| ] | |
| all_dfs = [] | |
| progress = st.progress(0) | |
| status = st.empty() | |
| for i, symbol in enumerate(symbols): | |
| try: | |
| stock = Vnstock().stock(symbol=symbol, source="VCI") | |
| df = stock.quote.history( | |
| start=start_date, | |
| end=end_date, | |
| interval="1D" | |
| ) | |
| if df.empty: | |
| status.warning(f"{symbol} không có dữ liệu") | |
| continue | |
| # Feature engineering (GIỮ NGUYÊN) | |
| df["estimated_value"] = ( | |
| (df["open"] + df["high"] + df["low"] + df["close"]) / 4 | |
| ) * df["volume"] | |
| df["+/- price percent"] = df["close"].pct_change().mul(100).round(2) | |
| df["symbol"] = symbol | |
| all_dfs.append(df) | |
| status.info(f"Đã crawl xong {symbol}") | |
| except Exception as e: | |
| status.error(f"Lỗi {symbol}: {e}") | |
| progress.progress((i + 1) / len(symbols)) | |
| time.sleep(10) | |
| if len(all_dfs) == 0: | |
| st.error("Không crawl được dữ liệu nào") | |
| return | |
| final_df = pd.concat(all_dfs, ignore_index=True) | |
| final_df.to_csv(filename, index=False) | |
| st.success(f"Đã lưu {filename}") | |
| st.markdown("## 📥 Crawl dữ liệu VN30 thành 2 tập train và test") | |
| c1, c2 = st.columns(2) | |
| with c1: | |
| if st.button("⬇️ Crawl Train Data (Từ 2022 đến 2 tuần trước hiện tại)"): | |
| with st.spinner("Đang crawl Train data..."): | |
| end_date = (datetime.now() - timedelta(days=14)).strftime("%Y-%m-%d") | |
| crawl_vn30_data( | |
| start_date="2022-01-01", | |
| end_date=end_date, | |
| filename="VN30_Train.csv" | |
| ) | |
| st.cache_data.clear() | |
| st.rerun() | |
| with c2: | |
| if st.button("⬇️ Crawl Test Data (Từ 2 tuần trước đến hiện tại)"): | |
| with st.spinner("Đang crawl Test data..."): | |
| end_date = datetime.now().strftime("%Y-%m-%d") | |
| crawl_vn30_data( | |
| start_date="2022-01-01", | |
| end_date=end_date, | |
| filename="VN30_Test.csv" | |
| ) | |
| st.cache_data.clear() | |
| st.rerun() | |
| # 5. Làm button để cho người dùng chọn cổ phiếu và model | |
| st.markdown("## 📌 Chọn cổ phiếu và ngày sắp tới để dự đoán") | |
| left, right = st.columns([1, 1]) | |
| with left: | |
| symbol = st.selectbox( | |
| "Bạn hãy chọn cổ phiếu", | |
| symbols | |
| ) | |
| with right: | |
| horizon = st.number_input( | |
| "Bạn muốn forecast trong bao nhiêu ngày sắp tới", | |
| min_value=1, | |
| max_value=30, | |
| value=14, | |
| step=1 | |
| ) | |
| st.markdown("## 🧠 Chọn model để dự đoán") | |
| MODEL_OPTIONS = [ | |
| "Technical Analysis", | |
| "SARIMA", | |
| "PatchTST", | |
| "TimeFM", | |
| "TimeGPT (Recommended)", | |
| "TimeGPT with news (Beta)" | |
| ] | |
| selected_models = [] | |
| cols = st.columns(3) | |
| for i, m in enumerate(MODEL_OPTIONS): | |
| with cols[i % 3]: | |
| if st.checkbox(m, value=(m == "TimeGPT (Recommended)")): | |
| selected_models.append(m) | |
| if len(selected_models) == 0: | |
| st.warning("⚠️ Bạn cần chọn ít nhất 1 model") | |
| # 6. Load Models | |
| def load_timefm(): | |
| model = TimesFmModelForPrediction.from_pretrained( | |
| "google/timesfm-2.0-500m-pytorch", | |
| dtype=torch.bfloat16, | |
| attn_implementation="sdpa", | |
| device_map="auto" | |
| ) | |
| model.eval() | |
| return model | |
| def load_timegpt(): | |
| return NixtlaClient( | |
| api_key="nixak-zWQjbVl9QCc6eIFL3DDbaBXi09bnPKsa5jdUU7Q8izPpn3eYl0rZPWLLs8NI597PT0VzIODhPUmKzkMc" | |
| ) | |
| def load_patchtst(): | |
| config = PatchTSTConfig( | |
| context_length=31, | |
| prediction_length=3, | |
| input_size=1, | |
| patch_len=6, | |
| stride=3, | |
| d_model=96, | |
| num_hidden_layers=4, | |
| num_attention_heads=4, | |
| dropout=0.15 | |
| ) | |
| model = PatchTSTForPrediction(config) | |
| model.load_state_dict( | |
| torch.load("best_patchtst_vn30.pt", map_location="cpu") | |
| ) | |
| model.eval() | |
| scalers = joblib.load("patchtst_scalers_vn30.pkl") | |
| return model, scalers | |
| # 7. Forecast Function | |
| def compute_ci(df_sym, preds): | |
| returns = df_sym["close"].diff().dropna() | |
| vol = returns.rolling(20, min_periods=5).std().iloc[-1] | |
| z = 1.96 | |
| return preds - z * vol, preds + z * vol | |
| def forecast_timefm(symbol, horizon): | |
| model = load_timefm() | |
| df_sym = df_train[df_train["symbol"] == symbol] | |
| series = df_sym["close"].astype(float).values | |
| past_tensor = ( | |
| torch.from_numpy(series) | |
| .to(dtype=torch.bfloat16, device=model.device) | |
| .unsqueeze(0) | |
| ) | |
| freq_tensor = torch.tensor([0], dtype=torch.long, device=model.device) | |
| with torch.no_grad(): | |
| outputs = model( | |
| past_values=past_tensor, | |
| freq=freq_tensor, | |
| return_dict=True | |
| ) | |
| preds = outputs.mean_predictions[0].float().cpu().numpy()[:horizon] | |
| future_time = pd.bdate_range( | |
| start=df_sym["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| lo, hi = compute_ci(df_sym, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| def forecast_timegpt(symbol, horizon): | |
| client = load_timegpt() | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.rename(columns={"time": "ds", "close": "y"})[["ds", "y"]] | |
| df_sym = df_sym.sort_values("ds").drop_duplicates(subset="ds") | |
| # Create full business-day range | |
| full_range = pd.date_range( | |
| start=df_sym["ds"].min(), | |
| end=df_sym["ds"].max(), | |
| freq="B" | |
| ) | |
| df_sym = ( | |
| df_sym | |
| .set_index("ds") | |
| .reindex(full_range) | |
| .rename_axis("ds") | |
| .reset_index() | |
| ) | |
| # Fill missing prices | |
| df_sym["y"] = df_sym["y"].ffill().bfill() | |
| df_ts = df_sym.rename( | |
| columns={"time": "ds", "close": "y"} | |
| )[["ds", "y"]] | |
| df_ts["unique_id"] = symbol | |
| df_fc = client.forecast( | |
| df=df_ts, | |
| h=horizon, | |
| freq="B", | |
| level=[95] | |
| ) | |
| return df_fc.rename(columns={ | |
| "ds": "time", | |
| "TimeGPT": "forecast_close", | |
| "TimeGPT-lo-95": "lower_95", | |
| "TimeGPT-hi-95": "upper_95" | |
| })[["time", "forecast_close", "lower_95", "upper_95"]] | |
| def forecast_timegpt_news(symbol, horizon): | |
| client = load_timegpt() | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.rename(columns={"time": "ds", "close": "y"})[["ds", "y"]] | |
| df_sym = df_sym.sort_values("ds").drop_duplicates(subset="ds") | |
| # Create full business-day range | |
| full_range = pd.date_range( | |
| start=df_sym["ds"].min(), | |
| end=df_sym["ds"].max(), | |
| freq="B" | |
| ) | |
| df_sym = ( | |
| df_sym | |
| .set_index("ds") | |
| .reindex(full_range) | |
| .rename_axis("ds") | |
| .reset_index() | |
| ) | |
| # Fill missing prices | |
| df_sym["y"] = df_sym["y"].ffill().bfill() | |
| # News | |
| df_news_sym = df_news[df_news["symbol"] == symbol].copy() | |
| df_news_sym["date"] = pd.to_datetime(df_news_sym["date"]) | |
| df_sent = ( | |
| df_news_sym | |
| .groupby("date")["sentiment"] | |
| .mean() | |
| .reset_index() | |
| ) | |
| df_sym["date"] = df_sym["ds"].dt.normalize() | |
| df_ts = df_sym.merge(df_sent, on="date", how="left") | |
| # Fill missing sentiment | |
| df_ts["sentiment"] = df_ts["sentiment"].fillna(0.0) | |
| df_ts["unique_id"] = symbol | |
| df_fc = client.forecast( | |
| df=df_ts[["unique_id", "ds", "y", "sentiment"]], | |
| h=horizon, | |
| freq="B", | |
| level=[95] | |
| ) | |
| return df_fc.rename(columns={ | |
| "ds": "time", | |
| "TimeGPT": "forecast_close", | |
| "TimeGPT-lo-95": "lower_95", | |
| "TimeGPT-hi-95": "upper_95" | |
| })[["time", "forecast_close", "lower_95", "upper_95"]] | |
| def forecast_patchtst(symbol, horizon): | |
| model, scalers = load_patchtst() | |
| df_sym = df_train[df_train["symbol"] == symbol] | |
| scaler = scalers[symbol] | |
| values = df_sym["close"].values.reshape(-1, 1) | |
| scaled = scaler.transform(values).flatten() | |
| context = scaled[-31:].copy() | |
| preds_scaled = [] | |
| with torch.no_grad(): | |
| for _ in range(horizon): | |
| xb = torch.tensor(context, dtype=torch.float32).unsqueeze(0).unsqueeze(-1) | |
| out = model(past_values=xb) | |
| next_pred = out.prediction_outputs[0, 0, 0].item() | |
| preds_scaled.append(next_pred) | |
| context = np.roll(context, -1) | |
| context[-1] = next_pred | |
| preds = scaler.inverse_transform( | |
| np.array(preds_scaled).reshape(-1, 1) | |
| ).flatten() | |
| future_time = pd.bdate_range( | |
| start=df_sym["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| lo, hi = compute_ci(df_sym, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| def forecast_sarima(symbol, horizon=7): | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.sort_values("time").drop_duplicates(subset="time") | |
| # Endogenous (price) | |
| ts = ( | |
| df_sym | |
| .set_index("time")["close"] | |
| .astype(float) | |
| .asfreq("B") | |
| .ffill() | |
| ) | |
| # Exogenous variables | |
| df_exog = df_sym.set_index("time").asfreq("B").ffill() | |
| df_exog["return"] = df_exog["close"].pct_change().shift(1) | |
| df_exog["volume"] = np.log1p(df_exog["volume"]) | |
| exog = df_exog[["return", "volume"]].fillna(0) | |
| # Đồng bộ index | |
| exog = exog.loc[ts.index] | |
| # Model | |
| model = SARIMAX( | |
| ts, | |
| exog=exog, | |
| order=(2, 0, 2), | |
| seasonal_order=(1, 0, 1, 5), | |
| enforce_stationarity=False, | |
| enforce_invertibility=False | |
| ) | |
| result = model.fit(disp=False, maxiter=50) | |
| # Future exog | |
| future_time = pd.bdate_range( | |
| start=ts.index[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| mean_ret = df_exog["return"].tail(20).mean() | |
| mean_vol = df_exog["volume"].tail(20).mean() | |
| future_exog = pd.DataFrame( | |
| { | |
| "return": np.full(horizon, mean_ret), | |
| "volume": np.full(horizon, mean_vol) | |
| }, | |
| index=future_time | |
| ) | |
| fc = result.get_forecast( | |
| steps=horizon, | |
| exog=future_exog | |
| ) | |
| conf_int = fc.conf_int(alpha=0.05) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": fc.predicted_mean.values, | |
| "lower_95": conf_int.iloc[:, 0].values, | |
| "upper_95": conf_int.iloc[:, 1].values | |
| }) | |
| def forecast_technical(symbol, horizon=7): | |
| df = df_train[df_train["symbol"] == symbol].copy() | |
| df = df.sort_values("time").reset_index(drop=True) | |
| # Indicators | |
| df["MA20"] = df["close"].rolling(20).mean() | |
| df["MA50"] = df["close"].rolling(50).mean() | |
| delta = df["close"].diff() | |
| gain = delta.clip(lower=0) | |
| loss = -delta.clip(upper=0) | |
| avg_gain = gain.rolling(14).mean() | |
| avg_loss = loss.rolling(14).mean() | |
| rs = avg_gain / avg_loss | |
| df["RSI"] = 100 - (100 / (1 + rs)) | |
| def signal(row): | |
| if row["close"] > row["MA20"] > row["MA50"] and row["RSI"] > 50: | |
| return 1 | |
| elif row["close"] < row["MA20"] < row["MA50"] and row["RSI"] < 50: | |
| return -1 | |
| else: | |
| return 0 | |
| df["signal"] = df.apply(signal, axis=1) | |
| # Forecast | |
| last_price = df["close"].iloc[-1] | |
| last_signal = df["signal"].iloc[-1] | |
| returns = df["close"].pct_change().rolling(10).std() | |
| vol = returns.iloc[-1] | |
| vol = 0 if np.isnan(vol) else min(vol, 0.01) | |
| future_time = pd.bdate_range( | |
| start=df["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| prices = [] | |
| price = last_price | |
| for _ in range(horizon): | |
| if last_signal == 1: | |
| price *= (1 + vol) | |
| elif last_signal == -1: | |
| price *= (1 - vol) | |
| else: | |
| price *= (1 + 0.1 * vol) | |
| prices.append(price) | |
| preds = np.array(prices) | |
| lo, hi = compute_ci(df, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| # 8. Rolling backtest | |
| def rolling_backtest(df_sym, forecast_fn, window=500, step=5): | |
| errors = [] | |
| for start in range(0, len(df_sym) - window - horizon, step): | |
| train_slice = df_sym.iloc[start:start + window] | |
| test_slice = df_sym.iloc[start + window:start + window + horizon] | |
| df_pred = forecast_fn( | |
| train_slice["symbol"].iloc[0], | |
| horizon | |
| ) | |
| y_true = test_slice["close"].values[:len(df_pred)] | |
| y_pred = df_pred["forecast_close"].values | |
| # MAE | |
| mae = mean_absolute_error(y_true, y_pred) | |
| # RMSE | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| # MASE (naive = yesterday) | |
| naive = train_slice["close"].diff().dropna() | |
| mae_naive = naive.abs().mean() | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| # Directional Accuracy | |
| actual_dir = np.sign(np.diff(y_true)) | |
| pred_dir = np.sign(np.diff(y_pred)) | |
| da = (actual_dir == pred_dir).mean() if len(actual_dir) > 0 else np.nan | |
| errors.append({ | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| }) | |
| return pd.DataFrame(errors) | |
| def backtest_sarima(symbol, n_test=7): | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.sort_values("time").drop_duplicates(subset="time") | |
| # Endogenous | |
| ts = ( | |
| df_sym | |
| .set_index("time")["close"] | |
| .astype(float) | |
| .asfreq("B") | |
| .ffill() | |
| ) | |
| # Exogenous | |
| df_exog = df_sym.set_index("time").asfreq("B").ffill() | |
| df_exog["return"] = df_exog["close"].pct_change().shift(1) | |
| df_exog["volume"] = np.log1p(df_exog["volume"]) | |
| exog = df_exog[["return", "volume"]].fillna(0) | |
| # Đồng bộ index | |
| exog = exog.loc[ts.index] | |
| # Train / Test split | |
| ts_train = ts.iloc[:-n_test] | |
| ts_test = ts.iloc[-n_test:] | |
| exog_train = exog.loc[ts_train.index] | |
| exog_test = exog.loc[ts_test.index] | |
| # Model | |
| model = SARIMAX( | |
| ts_train, | |
| exog=exog_train, | |
| order=(2, 0, 2), | |
| seasonal_order=(1, 0, 1, 5), | |
| enforce_stationarity=False, | |
| enforce_invertibility=False | |
| ) | |
| result = model.fit(disp=False, maxiter=50) | |
| fc = result.get_forecast( | |
| steps=n_test, | |
| exog=exog_test | |
| ) | |
| y_pred = fc.predicted_mean | |
| y_true = ts_test | |
| # Metrics | |
| mae = mean_absolute_error(y_true, y_pred) | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| naive = ts_train.shift(1).dropna() | |
| mae_naive = mean_absolute_error( | |
| ts_train.loc[naive.index], | |
| naive | |
| ) | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| actual_dir = np.sign(y_true.diff().iloc[1:]) | |
| pred_dir = np.sign(y_pred.diff().iloc[1:]) | |
| da = (actual_dir == pred_dir).mean() | |
| return { | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| } | |
| def backtest_technical(symbol, n_test=7): | |
| df = df_train[df_train["symbol"] == symbol].copy() | |
| df = df.sort_values("time").reset_index(drop=True) | |
| train = df.iloc[:-n_test] | |
| test = df.iloc[-n_test:] | |
| df_fc = forecast_technical(symbol, horizon=n_test) | |
| y_true = test["close"].values | |
| y_pred = df_fc["forecast_close"].values | |
| # MAE | |
| mae = mean_absolute_error(y_true, y_pred) | |
| # RMSE | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| # MASE | |
| naive = train["close"].shift(1).dropna() | |
| mae_naive = np.abs(naive.values - train["close"].iloc[1:].values).mean() | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| # Directional Accuracy | |
| actual_dir = np.sign(np.diff(y_true)) | |
| pred_dir = np.sign(np.diff(y_pred)) | |
| da = (actual_dir == pred_dir).mean() | |
| return { | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| } | |
| # 9. Run Forecast | |
| if st.button("🚀 Forecast") and len(selected_models) > 0: | |
| progress = st.progress(0) | |
| status = st.empty() | |
| all_forecasts = {} | |
| all_metrics = {} | |
| total = len(selected_models) | |
| for i, model_name in enumerate(selected_models): | |
| status.info(f"⏳ Đang chạy model: **{model_name}**") | |
| # Forecast | |
| if model_name == "TimeFM": | |
| df_fc = forecast_timefm(symbol, horizon) | |
| bt_fn = forecast_timefm | |
| df_bt = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "TimeGPT (Recommended)": | |
| df_fc = forecast_timegpt(symbol, horizon) | |
| bt_fn = forecast_timegpt | |
| df_bt = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "TimeGPT with news (Beta)": | |
| df_fc = forecast_timegpt_news(symbol, horizon) | |
| bt_fn = forecast_timegpt_news | |
| df_bt = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "SARIMA": | |
| df_fc = forecast_sarima(symbol, horizon) | |
| df_bt = pd.DataFrame([backtest_sarima(symbol, n_test=horizon)]) | |
| elif model_name == "Technical Analysis": | |
| df_fc = forecast_technical(symbol, horizon) | |
| df_bt = pd.DataFrame([backtest_technical(symbol, n_test=horizon)]) | |
| else: # PatchTST | |
| df_fc = forecast_patchtst(symbol, horizon) | |
| bt_fn = forecast_patchtst | |
| df_bt = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| all_forecasts[model_name] = df_fc | |
| all_metrics[model_name] = df_bt.mean(numeric_only=True) | |
| progress.progress((i + 1) / total) | |
| status.success("Forecast hoàn tất!") | |
| st.session_state.all_forecasts = all_forecasts | |
| st.session_state.all_metrics = pd.DataFrame(all_metrics).T.reset_index().rename( | |
| columns={"index": "Model"} | |
| ) | |
| # 11. Plot | |
| if "all_forecasts" in st.session_state: | |
| st.subheader(f"📊 Forecast plot {symbol} – {horizon} ngày") | |
| model_to_view = st.selectbox( | |
| "🎛️ Chọn model để hiển thị", | |
| list(st.session_state.all_forecasts.keys()) | |
| ) | |
| df_fc = st.session_state.all_forecasts[model_to_view] | |
| df_hist = df_train[df_train["symbol"] == symbol] | |
| df_test_sym = df_test[df_test["symbol"] == symbol] | |
| last_train_time = df_hist["time"].iloc[-1] | |
| df_actual_future = df_test_sym[ | |
| df_test_sym["time"] >= last_train_time | |
| ].iloc[:horizon] | |
| fig = go.Figure() | |
| # Historical | |
| fig.add_trace(go.Scatter( | |
| x=df_hist["time"], | |
| y=df_hist["close"], | |
| mode="lines", | |
| name="Historical", | |
| line=dict(width=2.5) | |
| )) | |
| # Actual | |
| if len(df_actual_future) > 0: | |
| fig.add_trace(go.Scatter( | |
| x=df_actual_future["time"], | |
| y=df_actual_future["close"], | |
| mode="lines+markers", | |
| name="Actual" | |
| )) | |
| # CI | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["upper_95"], | |
| mode="lines", | |
| line=dict(width=0), | |
| name="95% CI", | |
| showlegend=True | |
| ) | |
| ) | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["lower_95"], | |
| mode="lines", | |
| line=dict(width=0), | |
| fill="tonexty", | |
| fillcolor="rgba(168, 85, 247, 0.35)", # tím xịn | |
| name=None, | |
| showlegend=False | |
| ) | |
| ) | |
| # Forecast | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["forecast_close"], | |
| mode="lines+markers", | |
| name=f"Forecast ({model_to_view})" | |
| )) | |
| fig.update_layout( | |
| template="plotly_dark", | |
| hovermode="x unified", | |
| dragmode="pan", | |
| ) | |
| st.plotly_chart( | |
| fig, | |
| use_container_width=True, | |
| config={"scrollZoom": True} | |
| ) | |
| # 12. Metrics | |
| if "all_metrics" in st.session_state: | |
| st.subheader("📐 So sánh Metrics") | |
| metric_options = { | |
| "MAE (giá trị càng nhỏ model càng tốt)": "MAE", | |
| "RMSE (giá trị càng nhỏ model càng tốt)": "RMSE", | |
| "MASE (giá trị càng nhỏ model càng tốt)": "MASE", | |
| "DA (giá trị càng lớn model càng tốt, trên 0.5 là model rất tốt)": "DA", | |
| } | |
| metric_label = st.selectbox( | |
| "🎛️ Chọn metric để so sánh", | |
| list(metric_options.keys()) | |
| ) | |
| metric_to_view = metric_options[metric_label] | |
| df_m = st.session_state.all_metrics.sort_values( | |
| metric_to_view, | |
| ascending=(metric_to_view != "DA") # DA càng cao càng tốt | |
| ) | |
| fig = go.Figure() | |
| fig.add_trace(go.Bar( | |
| x=df_m["Model"], | |
| y=df_m[metric_to_view], | |
| text=df_m[metric_to_view].round(3), | |
| textposition="auto" | |
| )) | |
| fig.update_layout( | |
| template="plotly_dark", | |
| yaxis_title=metric_to_view, | |
| xaxis_title="Model", | |
| dragmode="pan", | |
| ) | |
| st.plotly_chart( | |
| fig, | |
| use_container_width=True, | |
| config={"scrollZoom": True} | |
| ) | |
| if menu == "🧪 Demo sản phẩm": | |
| # 1. Page Config | |
| st.set_page_config( | |
| page_title="VN30 Forecast - Demo Product", | |
| layout="wide" | |
| ) | |
| # 2. Load data | |
| def load_data(): | |
| df = pd.read_csv("VN30_Test.csv") | |
| df["time"] = pd.to_datetime(df["time"]) | |
| df = df.sort_values(["symbol", "time"]) | |
| return df | |
| df_train = load_data() | |
| symbols = sorted(df_train["symbol"].unique()) | |
| def load_news(): | |
| df = pd.read_csv("VN30_news_with_sentiment.csv") | |
| df["date"] = pd.to_datetime(df["date"]) | |
| return df | |
| df_news = load_news() | |
| # 3. UI | |
| st.markdown('<h1 style="text-align: center;"><span class="emoji">📈</span> VN30 Stock Forecast - Demo Product</h1>', unsafe_allow_html=True) | |
| a1, a2, a3 = st.columns([1, 2, 1]) | |
| with a2: | |
| st.image( | |
| "VN30_Product_Thumb.png", | |
| width=1000 | |
| ) | |
| st.markdown("---") | |
| with st.expander("🎯 **Mục tiêu dự án** (bấm để xem chi tiết)"): | |
| st.markdown( | |
| """ | |
| <div class="hero-intro"> | |
| <h3 style="margin-top: 0;">🎯 Mục tiêu dự án</h3> | |
| <p>Dự án này nhằm xây dựng một <strong>ứng dụng dự báo giá cổ phiếu VN30 theo ngày</strong> bằng nhiều phương pháp forecasting khác nhau, bao gồm:</p> | |
| <ul style="margin: 1rem 0;"> | |
| <li><strong>Technical Analysis</strong>: Dự báo dựa trên xu hướng, động lượng và biến động giá thông qua các chỉ báo kỹ thuật.</li> | |
| <li><strong>SARIMA</strong>: Mô hình machine learning thống kê cổ điển cho chuỗi thời gian, kết hợp xu hướng, mùa vụ và biến ngoại sinh.</li> | |
| <li><strong>PatchTST</strong>: Mô hình deep learning - Transformer time-series dùng patch để tối ưu bộ nhớ.</li> | |
| <li><strong>TimeFM</strong>: Foundation model cho time series, model này chỉ nên ở mức tham khảo vì khả năng dự đoán không tốt.</li> | |
| <li><strong>TimeGPT</strong>: Generative forecasting model.</li> | |
| <li><strong>TimeGPT with news</strong>: TimeGPT kết hợp qwen3 xuất ra sentiment từ tin tức (beta), tin tức được lấy từ https://cafef.vn/</li> | |
| </ul> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # Xuất ra ngày hiện tại | |
| today_str = datetime.now().strftime("%d/%m/%Y") | |
| st.markdown( | |
| f""" | |
| <div class="date-display"> | |
| <p>📅 Ngày hiện tại: <b>{today_str}</b></p> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # 4. Crawl data | |
| def crawl_vn30_data(start_date, end_date, filename): | |
| symbols = [ | |
| "ACB","DGC","BCM","BID","FPT","HDB","HPG","LPB","MSN","MBB", | |
| "MWG","PLX","GAS","SAB","STB","SHB","SSB","SSI","TCB","TPB", | |
| "VCB","CTG","VJC","VIB","GVR","VNM","VRE","VIC","VHM","VPB" | |
| ] | |
| all_dfs = [] | |
| progress = st.progress(0) | |
| status = st.empty() | |
| for i, symbol in enumerate(symbols): | |
| try: | |
| stock = Vnstock().stock(symbol=symbol, source="VCI") | |
| df = stock.quote.history( | |
| start=start_date, | |
| end=end_date, | |
| interval="1D" | |
| ) | |
| if df.empty: | |
| status.warning(f"{symbol} không có dữ liệu") | |
| continue | |
| # Feature engineering (GIỮ NGUYÊN) | |
| df["estimated_value"] = ( | |
| (df["open"] + df["high"] + df["low"] + df["close"]) / 4 | |
| ) * df["volume"] | |
| df["+/- price percent"] = df["close"].pct_change().mul(100).round(2) | |
| df["symbol"] = symbol | |
| all_dfs.append(df) | |
| status.info(f"Đã crawl xong {symbol}") | |
| except Exception as e: | |
| status.error(f"Lỗi {symbol}: {e}") | |
| progress.progress((i + 1) / len(symbols)) | |
| time.sleep(10) | |
| if len(all_dfs) == 0: | |
| st.error("Không crawl được dữ liệu nào") | |
| return | |
| final_df = pd.concat(all_dfs, ignore_index=True) | |
| final_df.to_csv(filename, index=False) | |
| st.success(f"Đã lưu {filename}") | |
| st.markdown("## 📥 Crawl dữ liệu VN30") | |
| if st.button("⬇️ Crawl Data (Từ 2022 đến hiện tại)"): | |
| with st.spinner("Đang crawl data..."): | |
| end_date = datetime.now().strftime("%Y-%m-%d") | |
| crawl_vn30_data( | |
| start_date="2022-01-01", | |
| end_date=end_date, | |
| filename="VN30_Test.csv" | |
| ) | |
| st.cache_data.clear() | |
| st.rerun() | |
| # 5. Làm button để cho người dùng chọn cổ phiếu và model | |
| st.markdown("## 📌 Chọn cổ phiếu và ngày sắp tới để dự đoán") | |
| left, right = st.columns([1, 1]) | |
| with left: | |
| symbol = st.selectbox( | |
| "Bạn hãy chọn cổ phiếu", | |
| symbols | |
| ) | |
| with right: | |
| horizon = st.number_input( | |
| "Bạn muốn forecast trong bao nhiêu ngày sắp tới", | |
| min_value=1, | |
| max_value=30, | |
| value=14, | |
| step=1 | |
| ) | |
| st.markdown("## 🧠 Chọn model để dự đoán") | |
| model_name = st.radio( | |
| "", | |
| [ | |
| "Technical Analysis", | |
| "SARIMA", | |
| "PatchTST", | |
| "TimeFM", | |
| "TimeGPT (Recommended)", | |
| "TimeGPT with news (Beta)" | |
| ], | |
| horizontal=True | |
| ) | |
| # 6. Load Models | |
| def load_timefm(): | |
| model = TimesFmModelForPrediction.from_pretrained( | |
| "google/timesfm-2.0-500m-pytorch", | |
| dtype=torch.bfloat16, | |
| attn_implementation="sdpa", | |
| device_map="auto" | |
| ) | |
| model.eval() | |
| return model | |
| def load_timegpt(): | |
| return NixtlaClient( | |
| api_key="nixak-zWQjbVl9QCc6eIFL3DDbaBXi09bnPKsa5jdUU7Q8izPpn3eYl0rZPWLLs8NI597PT0VzIODhPUmKzkMc" | |
| ) | |
| def load_patchtst(): | |
| config = PatchTSTConfig( | |
| context_length=31, | |
| prediction_length=3, | |
| input_size=1, | |
| patch_len=6, | |
| stride=3, | |
| d_model=96, | |
| num_hidden_layers=4, | |
| num_attention_heads=4, | |
| dropout=0.15 | |
| ) | |
| model = PatchTSTForPrediction(config) | |
| model.load_state_dict( | |
| torch.load("best_patchtst_vn30.pt", map_location="cpu") | |
| ) | |
| model.eval() | |
| scalers = joblib.load("patchtst_scalers_vn30.pkl") | |
| return model, scalers | |
| # 7. Forecast Function | |
| def compute_ci(df_sym, preds): | |
| returns = df_sym["close"].diff().dropna() | |
| vol = returns.rolling(20, min_periods=5).std().iloc[-1] | |
| z = 1.96 | |
| return preds - z * vol, preds + z * vol | |
| def forecast_timefm(symbol, horizon): | |
| model = load_timefm() | |
| df_sym = df_train[df_train["symbol"] == symbol] | |
| series = df_sym["close"].astype(float).values | |
| past_tensor = ( | |
| torch.from_numpy(series) | |
| .to(dtype=torch.bfloat16, device=model.device) | |
| .unsqueeze(0) | |
| ) | |
| freq_tensor = torch.tensor([0], dtype=torch.long, device=model.device) | |
| with torch.no_grad(): | |
| outputs = model( | |
| past_values=past_tensor, | |
| freq=freq_tensor, | |
| return_dict=True | |
| ) | |
| preds = outputs.mean_predictions[0].float().cpu().numpy()[:horizon] | |
| future_time = pd.bdate_range( | |
| start=df_sym["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| lo, hi = compute_ci(df_sym, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| def forecast_timegpt(symbol, horizon): | |
| client = load_timegpt() | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.rename(columns={"time": "ds", "close": "y"})[["ds", "y"]] | |
| df_sym = df_sym.sort_values("ds").drop_duplicates(subset="ds") | |
| # Create full business-day range | |
| full_range = pd.date_range( | |
| start=df_sym["ds"].min(), | |
| end=df_sym["ds"].max(), | |
| freq="B" | |
| ) | |
| df_sym = ( | |
| df_sym | |
| .set_index("ds") | |
| .reindex(full_range) | |
| .rename_axis("ds") | |
| .reset_index() | |
| ) | |
| # Fill missing prices | |
| df_sym["y"] = df_sym["y"].ffill().bfill() | |
| df_ts = df_sym.rename( | |
| columns={"time": "ds", "close": "y"} | |
| )[["ds", "y"]] | |
| df_ts["unique_id"] = symbol | |
| df_fc = client.forecast( | |
| df=df_ts, | |
| h=horizon, | |
| freq="B", | |
| level=[95] | |
| ) | |
| return df_fc.rename(columns={ | |
| "ds": "time", | |
| "TimeGPT": "forecast_close", | |
| "TimeGPT-lo-95": "lower_95", | |
| "TimeGPT-hi-95": "upper_95" | |
| })[["time", "forecast_close", "lower_95", "upper_95"]] | |
| def forecast_timegpt_news(symbol, horizon): | |
| client = load_timegpt() | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.rename(columns={"time": "ds", "close": "y"})[["ds", "y"]] | |
| df_sym = df_sym.sort_values("ds").drop_duplicates(subset="ds") | |
| # Create full business-day range | |
| full_range = pd.date_range( | |
| start=df_sym["ds"].min(), | |
| end=df_sym["ds"].max(), | |
| freq="B" | |
| ) | |
| df_sym = ( | |
| df_sym | |
| .set_index("ds") | |
| .reindex(full_range) | |
| .rename_axis("ds") | |
| .reset_index() | |
| ) | |
| # Fill missing prices | |
| df_sym["y"] = df_sym["y"].ffill().bfill() | |
| # News | |
| df_news_sym = df_news[df_news["symbol"] == symbol].copy() | |
| df_news_sym["date"] = pd.to_datetime(df_news_sym["date"]) | |
| df_sent = ( | |
| df_news_sym | |
| .groupby("date")["sentiment"] | |
| .mean() | |
| .reset_index() | |
| ) | |
| df_sym["date"] = df_sym["ds"].dt.normalize() | |
| df_ts = df_sym.merge(df_sent, on="date", how="left") | |
| # Fill missing sentiment | |
| df_ts["sentiment"] = df_ts["sentiment"].fillna(0.0) | |
| df_ts["unique_id"] = symbol | |
| df_fc = client.forecast( | |
| df=df_ts[["unique_id", "ds", "y", "sentiment"]], | |
| h=horizon, | |
| freq="B", | |
| level=[95] | |
| ) | |
| return df_fc.rename(columns={ | |
| "ds": "time", | |
| "TimeGPT": "forecast_close", | |
| "TimeGPT-lo-95": "lower_95", | |
| "TimeGPT-hi-95": "upper_95" | |
| })[["time", "forecast_close", "lower_95", "upper_95"]] | |
| def forecast_patchtst(symbol, horizon): | |
| model, scalers = load_patchtst() | |
| df_sym = df_train[df_train["symbol"] == symbol] | |
| scaler = scalers[symbol] | |
| values = df_sym["close"].values.reshape(-1, 1) | |
| scaled = scaler.transform(values).flatten() | |
| context = scaled[-31:].copy() | |
| preds_scaled = [] | |
| with torch.no_grad(): | |
| for _ in range(horizon): | |
| xb = torch.tensor(context, dtype=torch.float32).unsqueeze(0).unsqueeze(-1) | |
| out = model(past_values=xb) | |
| next_pred = out.prediction_outputs[0, 0, 0].item() | |
| preds_scaled.append(next_pred) | |
| context = np.roll(context, -1) | |
| context[-1] = next_pred | |
| preds = scaler.inverse_transform( | |
| np.array(preds_scaled).reshape(-1, 1) | |
| ).flatten() | |
| future_time = pd.bdate_range( | |
| start=df_sym["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| lo, hi = compute_ci(df_sym, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| def forecast_sarima(symbol, horizon=7): | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.sort_values("time").drop_duplicates(subset="time") | |
| # Endogenous (price) | |
| ts = ( | |
| df_sym | |
| .set_index("time")["close"] | |
| .astype(float) | |
| .asfreq("B") | |
| .ffill() | |
| ) | |
| # Exogenous variables | |
| df_exog = df_sym.set_index("time").asfreq("B").ffill() | |
| df_exog["return"] = df_exog["close"].pct_change().shift(1) | |
| df_exog["volume"] = np.log1p(df_exog["volume"]) | |
| exog = df_exog[["return", "volume"]].fillna(0) | |
| # Đồng bộ index | |
| exog = exog.loc[ts.index] | |
| # Model | |
| model = SARIMAX( | |
| ts, | |
| exog=exog, | |
| order=(2, 0, 2), | |
| seasonal_order=(1, 0, 1, 5), | |
| enforce_stationarity=False, | |
| enforce_invertibility=False | |
| ) | |
| result = model.fit(disp=False, maxiter=50) | |
| # Future exog | |
| future_time = pd.bdate_range( | |
| start=ts.index[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| mean_ret = df_exog["return"].tail(20).mean() | |
| mean_vol = df_exog["volume"].tail(20).mean() | |
| future_exog = pd.DataFrame( | |
| { | |
| "return": np.full(horizon, mean_ret), | |
| "volume": np.full(horizon, mean_vol) | |
| }, | |
| index=future_time | |
| ) | |
| fc = result.get_forecast( | |
| steps=horizon, | |
| exog=future_exog | |
| ) | |
| conf_int = fc.conf_int(alpha=0.05) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": fc.predicted_mean.values, | |
| "lower_95": conf_int.iloc[:, 0].values, | |
| "upper_95": conf_int.iloc[:, 1].values | |
| }) | |
| def forecast_technical(symbol, horizon=7): | |
| df = df_train[df_train["symbol"] == symbol].copy() | |
| df = df.sort_values("time").reset_index(drop=True) | |
| # Indicators | |
| df["MA20"] = df["close"].rolling(20).mean() | |
| df["MA50"] = df["close"].rolling(50).mean() | |
| delta = df["close"].diff() | |
| gain = delta.clip(lower=0) | |
| loss = -delta.clip(upper=0) | |
| avg_gain = gain.rolling(14).mean() | |
| avg_loss = loss.rolling(14).mean() | |
| rs = avg_gain / avg_loss | |
| df["RSI"] = 100 - (100 / (1 + rs)) | |
| def signal(row): | |
| if row["close"] > row["MA20"] > row["MA50"] and row["RSI"] > 50: | |
| return 1 | |
| elif row["close"] < row["MA20"] < row["MA50"] and row["RSI"] < 50: | |
| return -1 | |
| else: | |
| return 0 | |
| df["signal"] = df.apply(signal, axis=1) | |
| # Forecast | |
| last_price = df["close"].iloc[-1] | |
| last_signal = df["signal"].iloc[-1] | |
| returns = df["close"].pct_change().rolling(10).std() | |
| vol = returns.iloc[-1] | |
| vol = 0 if np.isnan(vol) else min(vol, 0.01) | |
| future_time = pd.bdate_range( | |
| start=df["time"].iloc[-1] + pd.offsets.BDay(), | |
| periods=horizon | |
| ) | |
| prices = [] | |
| price = last_price | |
| for _ in range(horizon): | |
| if last_signal == 1: | |
| price *= (1 + vol) | |
| elif last_signal == -1: | |
| price *= (1 - vol) | |
| else: | |
| price *= (1 + 0.1 * vol) | |
| prices.append(price) | |
| preds = np.array(prices) | |
| lo, hi = compute_ci(df, preds) | |
| return pd.DataFrame({ | |
| "time": future_time, | |
| "forecast_close": preds, | |
| "lower_95": lo, | |
| "upper_95": hi | |
| }) | |
| # 8. Rolling backtest | |
| def rolling_backtest(df_sym, forecast_fn, window=500, step=5): | |
| errors = [] | |
| for start in range(0, len(df_sym) - window - horizon, step): | |
| train_slice = df_sym.iloc[start:start + window] | |
| test_slice = df_sym.iloc[start + window:start + window + horizon] | |
| df_pred = forecast_fn( | |
| train_slice["symbol"].iloc[0], | |
| horizon | |
| ) | |
| y_true = test_slice["close"].values[:len(df_pred)] | |
| y_pred = df_pred["forecast_close"].values | |
| # MAE | |
| mae = mean_absolute_error(y_true, y_pred) | |
| # RMSE | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| # MASE (naive = yesterday) | |
| naive = train_slice["close"].diff().dropna() | |
| mae_naive = naive.abs().mean() | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| # Directional Accuracy | |
| actual_dir = np.sign(np.diff(y_true)) | |
| pred_dir = np.sign(np.diff(y_pred)) | |
| da = (actual_dir == pred_dir).mean() if len(actual_dir) > 0 else np.nan | |
| errors.append({ | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| }) | |
| return pd.DataFrame(errors) | |
| def backtest_sarima(symbol, n_test=7): | |
| df_sym = df_train[df_train["symbol"] == symbol].copy() | |
| df_sym = df_sym.sort_values("time").drop_duplicates(subset="time") | |
| # Endogenous | |
| ts = ( | |
| df_sym | |
| .set_index("time")["close"] | |
| .astype(float) | |
| .asfreq("B") | |
| .ffill() | |
| ) | |
| # Exogenous | |
| df_exog = df_sym.set_index("time").asfreq("B").ffill() | |
| df_exog["return"] = df_exog["close"].pct_change().shift(1) | |
| df_exog["volume"] = np.log1p(df_exog["volume"]) | |
| exog = df_exog[["return", "volume"]].fillna(0) | |
| # Đồng bộ index | |
| exog = exog.loc[ts.index] | |
| # Train / Test split | |
| ts_train = ts.iloc[:-n_test] | |
| ts_test = ts.iloc[-n_test:] | |
| exog_train = exog.loc[ts_train.index] | |
| exog_test = exog.loc[ts_test.index] | |
| # Model | |
| model = SARIMAX( | |
| ts_train, | |
| exog=exog_train, | |
| order=(2, 0, 2), | |
| seasonal_order=(1, 0, 1, 5), | |
| enforce_stationarity=False, | |
| enforce_invertibility=False | |
| ) | |
| result = model.fit(disp=False, maxiter=50) | |
| fc = result.get_forecast( | |
| steps=n_test, | |
| exog=exog_test | |
| ) | |
| y_pred = fc.predicted_mean | |
| y_true = ts_test | |
| # Metrics | |
| mae = mean_absolute_error(y_true, y_pred) | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| naive = ts_train.shift(1).dropna() | |
| mae_naive = mean_absolute_error( | |
| ts_train.loc[naive.index], | |
| naive | |
| ) | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| actual_dir = np.sign(y_true.diff().iloc[1:]) | |
| pred_dir = np.sign(y_pred.diff().iloc[1:]) | |
| da = (actual_dir == pred_dir).mean() | |
| return { | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| } | |
| def backtest_technical(symbol, n_test=7): | |
| df = df_train[df_train["symbol"] == symbol].copy() | |
| df = df.sort_values("time").reset_index(drop=True) | |
| train = df.iloc[:-n_test] | |
| test = df.iloc[-n_test:] | |
| df_fc = forecast_technical(symbol, horizon=n_test) | |
| y_true = test["close"].values | |
| y_pred = df_fc["forecast_close"].values | |
| # MAE | |
| mae = mean_absolute_error(y_true, y_pred) | |
| # RMSE | |
| rmse = root_mean_squared_error(y_true, y_pred) | |
| # MASE | |
| naive = train["close"].shift(1).dropna() | |
| mae_naive = np.abs(naive.values - train["close"].iloc[1:].values).mean() | |
| mase = mae / mae_naive if mae_naive != 0 else np.nan | |
| # Directional Accuracy | |
| actual_dir = np.sign(np.diff(y_true)) | |
| pred_dir = np.sign(np.diff(y_pred)) | |
| da = (actual_dir == pred_dir).mean() | |
| return { | |
| "MAE": mae, | |
| "RMSE": rmse, | |
| "MASE": mase, | |
| "DA": da | |
| } | |
| # 7. Run Forecast | |
| if st.button("🚀 Forecast"): | |
| status = st.empty() | |
| status.info(f"⏳ Đang chạy model: **{model_name}**") | |
| # Forecast theo model được chọn | |
| if model_name == "TimeFM": | |
| df_fc = forecast_timefm(symbol, horizon) | |
| bt_fn = forecast_timefm | |
| bt_err = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "TimeGPT (Recommended)": | |
| df_fc = forecast_timegpt(symbol, horizon) | |
| bt_fn = forecast_timegpt | |
| bt_err = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "TimeGPT with news (Beta)": | |
| df_fc = forecast_timegpt_news(symbol, horizon) | |
| bt_fn = forecast_timegpt_news | |
| bt_err = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| elif model_name == "SARIMA": | |
| df_fc = forecast_sarima(symbol, horizon) | |
| bt_err = pd.DataFrame([backtest_sarima(symbol, n_test=horizon)]) | |
| elif model_name == "Technical Analysis": | |
| df_fc = forecast_technical(symbol, horizon) | |
| bt_err = pd.DataFrame([backtest_technical(symbol, n_test=horizon)]) | |
| else: # PatchTST | |
| df_fc = forecast_patchtst(symbol, horizon) | |
| bt_fn = forecast_patchtst | |
| bt_err = rolling_backtest( | |
| df_train[df_train["symbol"] == symbol].reset_index(drop=True), | |
| bt_fn | |
| ) | |
| # Lưu vào session_state | |
| st.session_state.df_fc = df_fc | |
| st.session_state.bt_err = bt_err | |
| status.success("Forecast hoàn tất!") | |
| # 9. Table | |
| if "df_fc" in st.session_state: | |
| st.subheader(f"📋 Forecast Table {symbol} – {horizon} ngày ({model_name})") | |
| df_display = ( | |
| st.session_state.df_fc | |
| .rename(columns={ | |
| "time": "Time", | |
| "forecast_close": "Forecast Close Price", | |
| "lower_95": "Lower 95% CI", | |
| "upper_95": "Upper 95% CI" | |
| }) | |
| .assign( | |
| **{ | |
| "Forecast Close Price": lambda x: x["Forecast Close Price"].round(2), | |
| "Lower 95% CI": lambda x: x["Lower 95% CI"].round(2), | |
| "Upper 95% CI": lambda x: x["Upper 95% CI"].round(2), | |
| } | |
| ) | |
| ) | |
| st.dataframe(df_display, use_container_width=True, height=min(len(df_display), 30) * 35 + 40) | |
| # 10. Plot | |
| if "df_fc" in st.session_state: | |
| df_fc = st.session_state.df_fc | |
| df_hist = df_train[df_train["symbol"] == symbol] | |
| st.subheader(f"📊 Forecast plot {symbol} – {horizon} ngày ({model_name})") | |
| fig = go.Figure() | |
| # Historical | |
| fig.add_trace(go.Scatter( | |
| x=df_hist["time"], | |
| y=df_hist["close"], | |
| mode="lines", | |
| name="Historical", | |
| line=dict(width=2.5) | |
| )) | |
| # CI | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["upper_95"], | |
| mode="lines", | |
| line=dict(width=0), | |
| name="95% CI", | |
| showlegend=True | |
| ) | |
| ) | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["lower_95"], | |
| mode="lines", | |
| line=dict(width=0), | |
| fill="tonexty", | |
| fillcolor="rgba(168, 85, 247, 0.35)", # tím xịn | |
| name=None, | |
| showlegend=False | |
| ) | |
| ) | |
| # Forecast | |
| fig.add_trace(go.Scatter( | |
| x=df_fc["time"], | |
| y=df_fc["forecast_close"], | |
| mode="lines+markers", | |
| name=f"Forecast ({model_name})" | |
| )) | |
| fig.update_layout( | |
| template="plotly_dark", | |
| hovermode="x unified", | |
| dragmode="pan", | |
| ) | |
| st.plotly_chart( | |
| fig, | |
| use_container_width=True, | |
| config={"scrollZoom": True} | |
| ) | |
| # 11. Metrics | |
| if "bt_err" in st.session_state: | |
| st.subheader("📐 Metrics") | |
| c1, c2, c3, c4 = st.columns(4) | |
| with c1: | |
| st.metric( | |
| "MAE (avg)", | |
| f"{st.session_state.bt_err['MAE'].mean():.3f}" | |
| ) | |
| with c2: | |
| st.metric( | |
| "RMSE (avg)", | |
| f"{st.session_state.bt_err['RMSE'].mean():.3f}" | |
| ) | |
| with c3: | |
| st.metric( | |
| "MASE (avg)", | |
| f"{st.session_state.bt_err['MASE'].mean():.3f}" | |
| ) | |
| with c4: | |
| st.metric( | |
| "DA (avg)", | |
| f"{st.session_state.bt_err['DA'].mean() * 100:.1f}%" | |
| ) | |
| if menu == "📘 Tài liệu": | |
| # 1. Page Config | |
| st.set_page_config( | |
| page_title="VN30 Forecast - Documentation", | |
| layout="wide" | |
| ) | |
| # 2. UI | |
| st.markdown('<h1 style="text-align: center;"><span class="emoji">📈</span> VN30 Stock Forecast - Documentation</h1>', unsafe_allow_html=True) | |
| a1, a2, a3 = st.columns([1, 2, 1]) | |
| with a2: | |
| st.image( | |
| "VN30_Doc_Thumb.png", | |
| width=1000 | |
| ) | |
| st.markdown("---") | |
| # 3. Tổng quan hệ thống | |
| st.markdown("## 🔍 Tổng quan hệ thống") | |
| st.markdown( | |
| """ | |
| Ứng dụng **VN30 Stock Forecast** được xây dựng nhằm mục tiêu dự báo **giá đóng cửa (Close price)** | |
| của các cổ phiếu thuộc rổ **VN30** theo tần suất **ngày giao dịch (Business Day)**. | |
| Hệ thống kết hợp nhiều phương pháp dự báo khác nhau, từ: | |
| - Phương pháp truyền thống | |
| - Mô hình thống kê | |
| - Deep Learning | |
| - Foundation & Generative Models | |
| Pipeline tổng thể gồm **3 bước chính**: | |
| 1. Thu thập dữ liệu (Crawl data) | |
| 2. Dự báo bằng nhiều mô hình | |
| 3. Đánh giá và so sánh mô hình | |
| """ | |
| ) | |
| # 4. Crawl data | |
| st.markdown("## 📥 Thu thập dữ liệu (Crawl Data)") | |
| st.markdown( | |
| """ | |
| ### 🔹 Nguồn dữ liệu | |
| Dữ liệu giá cổ phiếu được thu thập thông qua thư viện **`vnstock`**, | |
| sử dụng API từ các công ty chứng khoán tại Việt Nam (VCI). | |
| Các mã cổ phiếu thuộc rổ **VN30** được crawl với các trường: | |
| - `open`, `high`, `low`, `close` (giá mở cửa, cao nhất, thấp nhất, đóng cửa) | |
| - `volume` (khối lượng giao dịch) | |
| - `time` (ngày giao dịch) | |
| **Tần suất dữ liệu:** theo ngày | |
| **Khoảng thời gian:** từ năm 2022 đến hiện tại | |
| """ | |
| ) | |
| st.markdown( | |
| """ | |
| ### 🔹 Xử lý dữ liệu | |
| - Chuẩn hóa định dạng thời gian | |
| - Sắp xếp theo `symbol` và `time` | |
| - Đồng bộ theo ngày giao dịch (Business Day) | |
| - Forward fill / Backward fill cho các ngày thiếu dữ liệu | |
| """ | |
| ) | |
| st.markdown( | |
| """ | |
| ### 🔹 Tin tức & Sentiment (tuỳ chọn) | |
| - Tin tức được crawl từ **cafef.vn** | |
| - Mỗi bài viết được gán **sentiment score** | |
| - Sentiment được tổng hợp theo ngày | |
| - Sentiment được sử dụng như **biến ngoại sinh** cho mô hình TimeGPT with news | |
| """ | |
| ) | |
| # 3. Forecasting models | |
| st.markdown("## 🧠 Các mô hình dự báo") | |
| st.markdown("### 1. Technical Analysis") | |
| st.markdown( | |
| """ | |
| Mô hình dựa trên các chỉ báo kỹ thuật phổ biến: | |
| - Moving Average (MA20, MA50) | |
| - RSI (Relative Strength Index) | |
| Ý tưởng chính: | |
| - Xác định xu hướng giá | |
| - Xác định động lượng thị trường | |
| - Sinh tín hiệu mua / bán / giữ | |
| 👉 Đây là mô hình **baseline**, đơn giản và dễ diễn giải. | |
| """ | |
| ) | |
| st.markdown("### 2. SARIMA") | |
| st.markdown( | |
| """ | |
| SARIMA (Seasonal ARIMA) là mô hình thống kê cổ điển cho chuỗi thời gian. | |
| Đặc điểm: | |
| - Mô hình hóa xu hướng | |
| - Mô hình hóa mùa vụ | |
| - Kết hợp biến ngoại sinh | |
| Trong hệ thống: | |
| - Endogenous: giá đóng cửa | |
| - Exogenous: return, volume (log-transform) | |
| 👉 Phù hợp với dữ liệu tài chính có tính mùa vụ. | |
| """ | |
| ) | |
| st.markdown("### 3. PatchTST") | |
| st.markdown( | |
| """ | |
| PatchTST là mô hình Transformer cho chuỗi thời gian. | |
| Ý tưởng chính: | |
| - Chia chuỗi thời gian thành các patch | |
| - Giảm độ dài sequence | |
| - Học quan hệ dài hạn hiệu quả hơn | |
| 👉 Phù hợp cho dự báo nhiều bước (multi-step forecasting). | |
| """ | |
| ) | |
| st.markdown("### 4. TimeFM") | |
| st.markdown( | |
| """ | |
| TimeFM là **foundation model cho time series**, được huấn luyện trước trên tập dữ liệu lớn. | |
| Đặc điểm: | |
| - Không cần huấn luyện lại nhiều | |
| - Dự báo nhanh | |
| - Tổng quát hóa tốt | |
| 👉 Trong project, TimeFM được dùng ở mức **tham khảo**. | |
| """ | |
| ) | |
| st.markdown("### 5. TimeGPT (Recommended)") | |
| st.markdown( | |
| """ | |
| TimeGPT là **Generative Forecasting Model** của Nixtla. | |
| Đặc điểm: | |
| - Tự động học trend và seasonality | |
| - Sinh dự báo xác suất | |
| - Trả về khoảng tin cậy (Confidence Interval) | |
| 👉 Đây là **mô hình được khuyến nghị sử dụng** trong ứng dụng. | |
| """ | |
| ) | |
| st.markdown("### 6. TimeGPT with News (Beta)") | |
| st.markdown( | |
| """ | |
| Mở rộng của TimeGPT bằng cách thêm biến **sentiment từ tin tức** từ model LLM **qwen3**. | |
| Ý tưởng: | |
| - Giá cổ phiếu chịu ảnh hưởng bởi thông tin thị trường | |
| - Sentiment được đưa vào như biến ngoại sinh | |
| 👉 Hiện tại ở mức **Beta** vì Qwen3 khá nặng chỉ chạy trên máy local, nên chỉ mang tính thử nghiệm. | |
| """ | |
| ) | |
| # 4. Metrics | |
| st.markdown("## 📐 Các chỉ số đánh giá (Metrics)") | |
| st.markdown("### 🔹 MAE – Mean Absolute Error") | |
| st.markdown( | |
| """ | |
| Đo sai số tuyệt đối trung bình giữa giá thực tế và giá dự báo. | |
| 👉 **Giá trị càng nhỏ → model càng tốt** | |
| """ | |
| ) | |
| st.markdown("### 🔹 RMSE – Root Mean Squared Error") | |
| st.markdown( | |
| """ | |
| Đo sai số bình phương trung bình, nhạy cảm với các lỗi lớn. | |
| 👉 **Giá trị càng nhỏ → model càng tốt** | |
| """ | |
| ) | |
| st.markdown("### 🔹 MASE – Mean Absolute Scaled Error") | |
| st.markdown( | |
| """ | |
| So sánh mô hình với baseline naive (giá hôm nay = giá hôm qua). | |
| 👉 MASE < 1: model tốt hơn naive | |
| 👉 **Giá trị càng nhỏ → model càng tốt** | |
| """ | |
| ) | |
| st.markdown("### 🔹 DA – Directional Accuracy") | |
| st.markdown( | |
| """ | |
| Đo độ chính xác trong việc dự đoán **chiều hướng tăng / giảm** của giá. | |
| 👉 Giá trị từ 0 đến 1 | |
| 👉 **DA > 0.5: model dự đoán xu hướng tốt** | |
| """ | |
| ) |