Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import pytorch_lightning as pl | |
| from neuralforecast.core import NeuralForecast | |
| from neuralforecast.models import NHITS, TimesNet, LSTM, TFT | |
| from neuralforecast.losses.pytorch import HuberMQLoss | |
| from neuralforecast.utils import AirPassengersDF | |
| import time | |
| from st_aggrid import AgGrid | |
| from nixtla import NixtlaClient | |
| import os | |
# Render every page of the app using the full browser width.
st.set_page_config(layout="wide")
def load_model(path, freq):
    """Load a saved NeuralForecast object from the directory ``path``.

    Args:
        path: Directory the NeuralForecast object was saved to.
        freq: Frequency key of the checkpoint. It is accepted for call-site
            symmetry but is not used by this function.

    Returns:
        The object returned by ``NeuralForecast.load``.
    """
    return NeuralForecast.load(path=path)
@st.cache_resource
def load_all_models():
    """Load every pre-trained M4 checkpoint, grouped by model architecture.

    Streamlit re-executes this script on every widget interaction. Without
    caching, all 20 checkpoints were re-read from disk on each rerun.
    ``st.cache_resource`` loads them once and reuses the same objects across
    reruns and sessions.

    NOTE(review): cached objects are shared by all sessions. Confirm that
    calling ``predict(df=...)`` on a shared NeuralForecast instance is safe
    when several users forecast concurrently.

    Returns:
        Tuple ``(nhits_models, timesnet_models, lstm_models, tft_models)``.
        Each element is a dict mapping 'D', 'M', 'H', 'W', 'Y' to the object
        loaded from ``./M4/<Model>/<daily|monthly|hourly|weekly|yearly>``.
    """
    frequency_dirs = {
        'D': 'daily',
        'M': 'monthly',
        'H': 'hourly',
        'W': 'weekly',
        'Y': 'yearly',
    }

    def _load_family(model_dir):
        # One loaded checkpoint per frequency for a single architecture.
        return {
            freq: load_model(f'./M4/{model_dir}/{sub_dir}', freq)
            for freq, sub_dir in frequency_dirs.items()
        }

    # Order matters: callers unpack the tuple as (NHITS, TimesNet, LSTM, TFT).
    return tuple(_load_family(name) for name in ('NHITS', 'TimesNet', 'LSTM', 'TFT'))
def generate_forecast(model, df, tag=False):
    """Run ``model.predict`` and return its result.

    Args:
        model: Object exposing ``predict`` (a loaded NeuralForecast in this app).
        df: Input data passed to ``predict`` as the ``df`` keyword.
        tag: When equal to the string ``'retrain'``, ``predict`` is called
            with no arguments and ``df`` is ignored. Presumably the model then
            forecasts from the data it already holds; confirm.

    Returns:
        Whatever ``model.predict`` returns.
    """
    predict_kwargs = {} if tag == 'retrain' else {'df': df}
    return model.predict(**predict_kwargs)
def determine_frequency(df):
    """Infer the pandas frequency alias of the ``ds`` column of ``df``.

    Side effect: ``df['ds']`` is converted to datetimes in place. The page code
    passes this same frame on to the models and the plot, so the conversion is
    intentional.

    Duplicate timestamps are ignored (the first occurrence is kept, order is
    preserved) before inference.

    Args:
        df: DataFrame with a ``ds`` column of dates or date-like strings.

    Returns:
        The alias reported by ``pd.infer_freq`` (e.g. ``'D'``), or ``'D'`` when
        no regular frequency can be inferred. In that case a Streamlit warning
        is shown.
    """
    df['ds'] = pd.to_datetime(df['ds'])
    timestamps = pd.DatetimeIndex(df['ds'].drop_duplicates())
    try:
        freq = pd.infer_freq(timestamps)
    except ValueError:
        # infer_freq raises (instead of returning None) when given fewer than
        # three timestamps. Route that case to the same daily fallback.
        freq = None
    if not freq:
        st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.',icon="⚠️")
        freq = 'D'
    return freq
| import plotly.graph_objects as go | |
def plot_forecasts(forecast_df, train_df, title):
    """Plot history, the median forecast and an optional 90% band with Plotly.

    ``train_df`` and ``forecast_df`` are stacked and indexed by ``ds``. The
    history comes from the ``y`` column. The forecast comes from the first
    column whose name contains ``'median'``. The band is drawn only when
    columns containing both ``'lo-90'`` and ``'hi-90'`` are present.

    Args:
        forecast_df: Model output with a ``ds`` column and quantile columns.
        train_df: Input data with ``ds`` and ``y`` columns.
        title: Figure title.

    Raises:
        KeyError: If no column name contains ``'median'``.
    """
    combined = pd.concat([train_df, forecast_df]).set_index('ds')

    def _first_column_containing(token):
        return next((name for name in combined.columns if token in name), None)

    median_col = _first_column_containing('median')
    lower_col = _first_column_containing('lo-90')
    upper_col = _first_column_containing('hi-90')
    if median_col is None:
        raise KeyError("No forecast column found in the data.")

    timeline = combined.index
    band_colour = 'rgba(0,100,80,0.2)'
    traces = [
        go.Scatter(x=timeline, y=combined['y'], mode='lines', name='Historical'),
        go.Scatter(x=timeline, y=combined[median_col], mode='lines', name='Forecast'),
    ]
    if lower_col and upper_col:
        # Upper edge first; the lower edge then fills up to it ('tonexty').
        traces.append(go.Scatter(
            x=timeline,
            y=combined[upper_col],
            mode='lines',
            line=dict(color=band_colour),
            showlegend=False,
        ))
        traces.append(go.Scatter(
            x=timeline,
            y=combined[lower_col],
            mode='lines',
            line=dict(color=band_colour),
            fill='tonexty',
            fillcolor=band_colour,
            name='90% Confidence Interval',
        ))

    figure = go.Figure()
    for trace in traces:
        figure.add_trace(trace)
    figure.update_layout(
        title=title,
        xaxis_title='Timestamp [t]',
        yaxis_title='Value',
        template='plotly_white',
    )
    st.plotly_chart(figure)
def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the pre-trained NHITS, TimesNet, LSTM and TFT models for ``freq``.

    ``freq`` is a pandas frequency alias as returned by ``pd.infer_freq``. The
    spelling differs across pandas versions: hourly is 'H' or 'h'; month-end
    is 'M' or 'ME'; year-end is 'A-DEC', 'Y-DEC' or 'YE-DEC'. All known
    spellings of the same periodicity map to one model key, and any anchor
    suffix (e.g. '-SUN', '-DEC') is ignored. Aliases are matched
    case-sensitively on purpose: lowercase 'ms' means milliseconds, not
    month-start.

    NOTE(review): start-anchored inputs ('MS', 'YS', 'AS') reuse the
    month-end/year-end models. The loaded model presumably generates future
    timestamps with its own frequency; confirm the alignment is acceptable.

    Args:
        freq: pandas frequency alias.
        nhits_models, timesnet_models, lstm_models, tft_models: Dicts keyed by
            'D', 'M', 'H', 'W', 'Y'.

    Returns:
        Tuple ``(nhits, timesnet, lstm, tft)`` for the matching key.

    Raises:
        ValueError: If ``freq`` has no pre-trained model.
    """
    alias_to_key = {
        'D': 'D',
        'H': 'H', 'h': 'H',
        'W': 'W',
        'M': 'M', 'ME': 'M', 'MS': 'M',
        'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
    }
    base_alias = freq.split('-', 1)[0] if isinstance(freq, str) else freq
    key = alias_to_key.get(base_alias)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]
def load_default():
    """Return a fresh copy of ``AirPassengersDF`` from ``neuralforecast.utils``.

    Copying lets callers modify the result without touching the shared
    module-level sample frame.
    """
    return AirPassengersDF.copy()
def _load_dataset():
    """Return the working DataFrame, rendering the CSV upload widgets as needed.

    Until a file has been uploaded once, the uploader is always shown. When
    nothing is uploaded, ``load_default()`` is used and stored in
    ``st.session_state.df``. After a first upload, the uploader appears only
    when the user ticks "Upload a new file (CSV)". Otherwise the stored
    ``st.session_state.df`` is reused.
    """
    has_previous_upload = 'uploaded_file' in st.session_state
    # `or` short-circuits, so the checkbox is only rendered after a first upload.
    if not has_previous_upload or st.checkbox("Upload a new file (CSV)"):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
            st.session_state.df = df
            st.session_state.uploaded_file = uploaded_file
            return df
    if has_previous_upload:
        return st.session_state.df
    df = load_default()
    st.session_state.df = df
    return df


def transfer_learning_forecasting():
    """Render the "Zero-shot Forecasting" page.

    Flow:
    1. Let the user pick or upload data and choose the date and target columns.
    2. Detect the data frequency and select the matching pre-trained models.
    3. On Submit, forecast with the chosen model and plot the first ``horizon``
       rows.
    4. Show the input and forecast tables. Results persist in
       ``st.session_state.forecast_results``.
    """
    st.title("Zero-shot Forecasting")
    st.markdown("""
    Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
    """)
    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        df = _load_dataset()
        columns = df.columns.tolist()
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        target_columns = [col for col in columns if (col != ds_col) and (col != 'unique_id')]
        if not target_columns:
            # Without a target the selectbox would return None and the rename below
            # would leave no 'y' column, crashing with a KeyError.
            st.error("The dataset needs at least one target column besides the date/time column.")
            st.stop()
        y_col = st.selectbox("Select Target column", options=target_columns, index=0)
        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col
    # Model selection and forecasting
    st.sidebar.subheader("Model Selection and Forecasting")
    model_names = ["NHITS", "TimesNet", "LSTM", "TFT"]
    model_choice = st.sidebar.selectbox("Select model", model_names)
    # A horizon < 1 would make `iloc[:horizon]` below drop rows from the end
    # of the forecast instead of keeping the first `horizon` steps.
    horizon = st.sidebar.number_input("Forecast horizon", value=12, min_value=1, step=1)
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    # .copy() gives an independent frame. determine_frequency converts 'ds' in
    # place, and this frame (with datetimes) is what the models receive; the
    # copy avoids writing into a pandas "copy" (SettingWithCopyWarning).
    df = df[['unique_id', 'ds', 'y']].copy()
    # Determine frequency of data
    frequency = determine_frequency(df)
    st.sidebar.write(f"Detected frequency: {frequency}")
    try:
        selected_models = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
    except ValueError as err:
        # Report unsupported frequencies on the page instead of crashing it.
        st.error(str(err))
        st.stop()
    models_by_name = dict(zip(model_names, selected_models))
    forecast_results = {}
    if st.sidebar.button("Submit"):
        start_time = time.time()  # Start timing
        forecast_results[model_choice] = generate_forecast(models_by_name[model_choice], df)
        st.session_state.forecast_results = forecast_results
        for model_name, forecast_df in forecast_results.items():
            plot_forecasts(forecast_df.iloc[:horizon, :], df, f'{model_name} Forecast for {y_col}')
        time_taken = time.time() - start_time
        st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
    if 'forecast_results' in st.session_state:
        forecast_results = st.session_state.forecast_results
        st.markdown('You can download Input and Forecast Data below')
        tab_insample, tab_forecast = st.tabs(
            ["Input data", "Forecast"]
        )
        with tab_insample:
            st.write(df.drop(columns="unique_id"))
        with tab_forecast:
            # Results are stored per model name; only show the currently selected one.
            if model_choice in forecast_results:
                st.write(forecast_results[model_choice])
def personalized_forecasting():
    """Render the placeholder "Personalized Forecasting" page."""
    page_heading = 'Personalized Forecasting'
    st.title(page_heading)
    st.subheader("Coming soon. Stay tuned")
# Register both pages under one "Neuralforecast" section and run the selected one.
# The zero-shot page is the default landing page.
zero_shot_page = st.Page(
    transfer_learning_forecasting,
    title="Zero-shot Forecasting",
    default=True,
    icon=":material/query_stats:",
)
personalized_page = st.Page(
    personalized_forecasting,
    title="Personalized Forecasting",
    icon=":material/star:",
)
pg = st.navigation({"Neuralforecast": [zero_shot_page, personalized_page]})
pg.run()