Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import matplotlib.pyplot as plt
|
|
|
|
| 4 |
from neuralforecast.core import NeuralForecast
|
| 5 |
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
| 6 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
|
@@ -9,35 +10,35 @@ import time
|
|
| 9 |
|
| 10 |
# Paths for saving models
|
| 11 |
nhits_paths = {
|
| 12 |
-
'D': './M4/NHITS/daily',
|
| 13 |
-
'M': './M4/NHITS/monthly',
|
| 14 |
-
'H': './M4/NHITS/hourly',
|
| 15 |
-
'W': './M4/NHITS/weekly',
|
| 16 |
-
'Y': './M4/NHITS/yearly'
|
| 17 |
}
|
| 18 |
|
| 19 |
timesnet_paths = {
|
| 20 |
-
'D': './M4/TimesNet/daily',
|
| 21 |
-
'M': './M4/TimesNet/monthly',
|
| 22 |
-
'H': './M4/TimesNet/hourly',
|
| 23 |
-
'W': './M4/TimesNet/weekly',
|
| 24 |
-
'Y': './M4/TimesNet/yearly'
|
| 25 |
}
|
| 26 |
|
| 27 |
lstm_paths = {
|
| 28 |
-
'D': './M4/LSTM/daily',
|
| 29 |
-
'M': './M4/LSTM/monthly',
|
| 30 |
-
'H': './M4/LSTM/hourly',
|
| 31 |
-
'W': './M4/LSTM/weekly',
|
| 32 |
-
'Y': './M4/LSTM/yearly'
|
| 33 |
}
|
| 34 |
|
| 35 |
tft_paths = {
|
| 36 |
-
'D': './M4/TFT/daily',
|
| 37 |
-
'M': './M4/TFT/monthly',
|
| 38 |
-
'H': './M4/TFT/hourly',
|
| 39 |
-
'W': './M4/TFT/weekly',
|
| 40 |
-
'Y': './M4/TFT/yearly'
|
| 41 |
}
|
| 42 |
|
| 43 |
@st.cache_resource
|
|
@@ -164,7 +165,7 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
|
|
| 164 |
model = select_model(horizon, model_type, max_steps)
|
| 165 |
forecast_results = {}
|
| 166 |
st.write(f"Generating forecast using {model_type} model...")
|
| 167 |
-
forecast_results[model_type] = generate_forecast(model, df
|
| 168 |
|
| 169 |
for model_name, forecast_df in forecast_results.items():
|
| 170 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|
|
@@ -173,49 +174,82 @@ def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
|
|
| 173 |
time_taken = end_time - start_time
|
| 174 |
st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
| 181 |
-
if uploaded_file:
|
| 182 |
-
df = pd.read_csv(uploaded_file)
|
| 183 |
-
else:
|
| 184 |
-
st.warning("Using default data")
|
| 185 |
-
df = AirPassengersDF.copy()
|
| 186 |
-
|
| 187 |
-
# Model selection and forecasting
|
| 188 |
-
st.subheader("Transfer Learning Forecasting")
|
| 189 |
-
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
| 190 |
-
horizon = st.slider("Forecast horizon", 1, 100, 10)
|
| 191 |
-
|
| 192 |
-
# Determine frequency of data
|
| 193 |
-
frequency = determine_frequency(df)
|
| 194 |
-
st.write(f"Detected frequency: {frequency}")
|
| 195 |
-
|
| 196 |
-
# Load pre-trained models
|
| 197 |
-
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
| 198 |
-
forecast_results = {}
|
| 199 |
-
|
| 200 |
-
start_time = time.time() # Start timing
|
| 201 |
-
if model_choice == "NHITS":
|
| 202 |
-
forecast_results['NHITS'] = generate_forecast(nhits_model, df)
|
| 203 |
-
elif model_choice == "TimesNet":
|
| 204 |
-
forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
|
| 205 |
-
elif model_choice == "LSTM":
|
| 206 |
-
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
| 207 |
-
elif model_choice == "TFT":
|
| 208 |
-
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
| 209 |
-
|
| 210 |
-
for model_name, forecast_df in forecast_results.items():
|
| 211 |
-
plot_forecasts(forecast_df, df, f'{model_name} Forecast')
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
from neuralforecast.core import NeuralForecast
|
| 6 |
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
| 7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
|
|
|
| 10 |
|
| 11 |
# Paths for saving models
|
| 12 |
nhits_paths = {
|
| 13 |
+
'D': './results/M4/NHITS/daily',
|
| 14 |
+
'M': './results/M4/NHITS/monthly',
|
| 15 |
+
'H': './results/M4/NHITS/hourly',
|
| 16 |
+
'W': './results/M4/NHITS/weekly',
|
| 17 |
+
'Y': './results/M4/NHITS/yearly'
|
| 18 |
}
|
| 19 |
|
| 20 |
timesnet_paths = {
|
| 21 |
+
'D': './results/M4/TimesNet/daily',
|
| 22 |
+
'M': './results/M4/TimesNet/monthly',
|
| 23 |
+
'H': './results/M4/TimesNet/hourly',
|
| 24 |
+
'W': './results/M4/TimesNet/weekly',
|
| 25 |
+
'Y': './results/M4/TimesNet/yearly'
|
| 26 |
}
|
| 27 |
|
| 28 |
lstm_paths = {
|
| 29 |
+
'D': './results/M4/LSTM/daily',
|
| 30 |
+
'M': './results/M4/LSTM/monthly',
|
| 31 |
+
'H': './results/M4/LSTM/hourly',
|
| 32 |
+
'W': './results/M4/LSTM/weekly',
|
| 33 |
+
'Y': './results/M4/LSTM/yearly'
|
| 34 |
}
|
| 35 |
|
| 36 |
tft_paths = {
|
| 37 |
+
'D': './results/M4/TFT/daily',
|
| 38 |
+
'M': './results/M4/TFT/monthly',
|
| 39 |
+
'H': './results/M4/TFT/hourly',
|
| 40 |
+
'W': './results/M4/TFT/weekly',
|
| 41 |
+
'Y': './results/M4/TFT/yearly'
|
| 42 |
}
|
| 43 |
|
| 44 |
@st.cache_resource
|
|
|
|
| 165 |
model = select_model(horizon, model_type, max_steps)
|
| 166 |
forecast_results = {}
|
| 167 |
st.write(f"Generating forecast using {model_type} model...")
|
| 168 |
+
forecast_results[model_type] = generate_forecast(model, df)
|
| 169 |
|
| 170 |
for model_name, forecast_df in forecast_results.items():
|
| 171 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast Comparison')
|
|
|
|
| 174 |
time_taken = end_time - start_time
|
| 175 |
st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
|
| 176 |
|
| 177 |
+
@st.cache_data
|
| 178 |
+
def load_default():
|
| 179 |
+
df = AirPassengersDf.copy()
|
| 180 |
+
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
+
def transfer_learning_forecasting():
|
| 183 |
+
st.title("Transfer Learning Forecasting")
|
| 184 |
+
|
| 185 |
+
# Upload dataset
|
| 186 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
| 187 |
+
if uploaded_file:
|
| 188 |
+
df = pd.read_csv(uploaded_file)
|
| 189 |
+
else:
|
| 190 |
+
df = load_default()
|
| 191 |
+
|
| 192 |
+
# Model selection and forecasting
|
| 193 |
+
st.subheader("Model Selection and Forecasting")
|
| 194 |
+
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
|
| 195 |
+
horizon = st.slider("Forecast horizon", 1, 100, 10)
|
| 196 |
+
|
| 197 |
+
# Determine frequency of data
|
| 198 |
+
frequency = determine_frequency(df)
|
| 199 |
+
st.write(f"Detected frequency: {frequency}")
|
| 200 |
+
|
| 201 |
+
# Load pre-trained models
|
| 202 |
+
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
| 203 |
+
forecast_results = {}
|
| 204 |
+
|
| 205 |
+
start_time = time.time() # Start timing
|
| 206 |
+
if model_choice == "NHITS":
|
| 207 |
+
forecast_results['NHITS'] = generate_forecast(nhits_model, df)
|
| 208 |
+
elif model_choice == "TimesNet":
|
| 209 |
+
forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
|
| 210 |
+
elif model_choice == "LSTM":
|
| 211 |
+
forecast_results['LSTM'] = generate_forecast(lstm_model, df)
|
| 212 |
+
elif model_choice == "TFT":
|
| 213 |
+
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
| 214 |
+
|
| 215 |
+
for model_name, forecast_df in forecast_results.items():
|
| 216 |
+
plot_forecasts(forecast_df, df, f'{model_name} Forecast')
|
| 217 |
+
|
| 218 |
+
end_time = time.time() # End timing
|
| 219 |
+
time_taken = end_time - start_time
|
| 220 |
+
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
| 221 |
+
|
| 222 |
+
def dynamic_forecasting():
|
| 223 |
+
st.title("Dynamic Forecasting")
|
| 224 |
+
|
| 225 |
+
# Upload dataset
|
| 226 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
| 227 |
+
if uploaded_file:
|
| 228 |
+
df = pd.read_csv(uploaded_file)
|
| 229 |
+
else:
|
| 230 |
+
df = load_default()
|
| 231 |
+
|
| 232 |
+
# Dynamic forecasting
|
| 233 |
+
st.subheader("Dynamic Model Selection and Forecasting")
|
| 234 |
+
dynamic_model_choice = st.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
| 235 |
+
dynamic_horizon = st.slider("Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon")
|
| 236 |
+
|
| 237 |
+
# Determine frequency of data
|
| 238 |
+
frequency = determine_frequency(df)
|
| 239 |
+
st.write(f"Detected frequency: {frequency}")
|
| 240 |
+
|
| 241 |
+
forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon
|
| 242 |
+
|
| 243 |
+
# Define the main navigation
|
| 244 |
+
pg = st.navigation({
|
| 245 |
+
"Overview": [
|
| 246 |
+
# Load pages from functions
|
| 247 |
+
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/library_books:"),
|
| 248 |
+
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/person:"),
|
| 249 |
+
]
|
| 250 |
+
})
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
pg.run()
|
| 254 |
+
except Exception as e:
|
| 255 |
+
st.error(f"Something went wrong: {str(e)}", icon=":material/error:")
|