Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,8 @@ from neuralforecast.losses.pytorch import HuberMQLoss
|
|
| 8 |
from neuralforecast.utils import AirPassengersDF
|
| 9 |
import time
|
| 10 |
from st_aggrid import AgGrid
|
|
|
|
|
|
|
| 11 |
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model(path, freq):
|
|
@@ -107,7 +109,7 @@ def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_mo
|
|
| 107 |
else:
|
| 108 |
raise ValueError(f"Unsupported frequency: {freq}")
|
| 109 |
|
| 110 |
-
def select_model(horizon, model_type, max_steps=
|
| 111 |
if model_type == 'NHITS':
|
| 112 |
return NHITS(input_size=5 * horizon,
|
| 113 |
h=horizon,
|
|
@@ -304,17 +306,45 @@ def dynamic_forecasting():
|
|
| 304 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
| 305 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
| 306 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
| 307 |
-
dynamic_max_steps = st.sidebar.number_input('Max steps', value=
|
| 308 |
|
| 309 |
if st.sidebar.button("Submit"):
|
| 310 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
pg = st.navigation({
|
| 313 |
-
"
|
| 314 |
# Load pages from functions
|
| 315 |
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
|
| 316 |
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
|
| 317 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
})
|
| 319 |
|
| 320 |
try:
|
|
|
|
| 8 |
from neuralforecast.utils import AirPassengersDF
|
| 9 |
import time
|
| 10 |
from st_aggrid import AgGrid
|
| 11 |
+
from nixtla import NixtlaClient
|
| 12 |
+
|
| 13 |
|
| 14 |
@st.cache_resource
|
| 15 |
def load_model(path, freq):
|
|
|
|
| 109 |
else:
|
| 110 |
raise ValueError(f"Unsupported frequency: {freq}")
|
| 111 |
|
| 112 |
+
def select_model(horizon, model_type, max_steps=50):
|
| 113 |
if model_type == 'NHITS':
|
| 114 |
return NHITS(input_size=5 * horizon,
|
| 115 |
h=horizon,
|
|
|
|
| 306 |
st.sidebar.subheader("Dynamic Model Selection and Forecasting")
|
| 307 |
dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
|
| 308 |
dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
|
| 309 |
+
dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)
|
| 310 |
|
| 311 |
if st.sidebar.button("Submit"):
|
| 312 |
forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps,y_col)
|
| 313 |
|
| 314 |
+
def timegpt():
|
| 315 |
+
st.title("TimeGPT Forecasting")
|
| 316 |
+
with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
|
| 317 |
+
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
|
| 318 |
+
if uploaded_file:
|
| 319 |
+
df = pd.read_csv(uploaded_file)
|
| 320 |
+
st.session_state.df = df
|
| 321 |
+
else:
|
| 322 |
+
df = load_default()
|
| 323 |
+
st.session_state.df = df
|
| 324 |
+
|
| 325 |
+
# Column selection
|
| 326 |
+
columns = df.columns.tolist() # Convert Index to list
|
| 327 |
+
ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
|
| 328 |
+
y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 1
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
|
| 337 |
pg = st.navigation({
|
| 338 |
+
"NeuralForecast": [
|
| 339 |
# Load pages from functions
|
| 340 |
st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
|
| 341 |
st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
|
| 342 |
+
],
|
| 343 |
+
"TimeGPT": [
|
| 344 |
+
# Load pages from functions
|
| 345 |
+
st.Page(timegpt, title="TimeGPT Forecast", icon=":material/smart_toy:")
|
| 346 |
+
st.Page(timegpt, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:")
|
| 347 |
+
]
|
| 348 |
})
|
| 349 |
|
| 350 |
try:
|