add distribution type
Browse files
app.py
CHANGED
|
@@ -1,8 +1,14 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
|
|
|
| 3 |
from gluonts.dataset.pandas import PandasDataset
|
| 4 |
from gluonts.dataset.split import split
|
| 5 |
from gluonts.torch.model.deepar import DeepAREstimator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from gluonts.evaluation import Evaluator, make_evaluation_predictions
|
| 7 |
|
| 8 |
from make_plot import plot_forecast, plot_train_test
|
|
@@ -32,6 +38,7 @@ def train_and_forecast(
|
|
| 32 |
prediction_length,
|
| 33 |
rolling_windows,
|
| 34 |
epochs,
|
|
|
|
| 35 |
progress=gr.Progress(track_tqdm=True),
|
| 36 |
):
|
| 37 |
if not input_data:
|
|
@@ -54,7 +61,14 @@ def train_and_forecast(
|
|
| 54 |
|
| 55 |
training_data, test_gen = split(gluon_df, offset=row_offset)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
estimator = DeepAREstimator(
|
|
|
|
| 58 |
prediction_length=prediction_length,
|
| 59 |
freq=gluon_df.freq,
|
| 60 |
trainer_kwargs=dict(max_epochs=epochs),
|
|
@@ -108,6 +122,11 @@ with gr.Blocks() as demo:
|
|
| 108 |
)
|
| 109 |
windows = gr.Number(value=3, label="Number of Windows", precision=0)
|
| 110 |
epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
with gr.Row(label="Dataset"):
|
| 113 |
upload_btn = gr.UploadButton(label="Upload")
|
|
@@ -122,7 +141,7 @@ with gr.Blocks() as demo:
|
|
| 122 |
)
|
| 123 |
train_btn.click(
|
| 124 |
fn=train_and_forecast,
|
| 125 |
-
inputs=[upload_btn, prediction_length, windows, epochs],
|
| 126 |
outputs=[plot, json],
|
| 127 |
)
|
| 128 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
| 3 |
+
|
| 4 |
from gluonts.dataset.pandas import PandasDataset
|
| 5 |
from gluonts.dataset.split import split
|
| 6 |
from gluonts.torch.model.deepar import DeepAREstimator
|
| 7 |
+
from gluonts.torch.distributions import (
|
| 8 |
+
NegativeBinomialOutput,
|
| 9 |
+
StudentTOutput,
|
| 10 |
+
NormalOutput,
|
| 11 |
+
)
|
| 12 |
from gluonts.evaluation import Evaluator, make_evaluation_predictions
|
| 13 |
|
| 14 |
from make_plot import plot_forecast, plot_train_test
|
|
|
|
| 38 |
prediction_length,
|
| 39 |
rolling_windows,
|
| 40 |
epochs,
|
| 41 |
+
distribution,
|
| 42 |
progress=gr.Progress(track_tqdm=True),
|
| 43 |
):
|
| 44 |
if not input_data:
|
|
|
|
| 61 |
|
| 62 |
training_data, test_gen = split(gluon_df, offset=row_offset)
|
| 63 |
|
| 64 |
+
if distribution == "StudentT":
|
| 65 |
+
distr_output = StudentTOutput()
|
| 66 |
+
elif distribution == "Normal":
|
| 67 |
+
distr_output = NormalOutput()
|
| 68 |
+
else:
|
| 69 |
+
distr_output = NegativeBinomialOutput()
|
| 70 |
estimator = DeepAREstimator(
|
| 71 |
+
distr_output=distr_output,
|
| 72 |
prediction_length=prediction_length,
|
| 73 |
freq=gluon_df.freq,
|
| 74 |
trainer_kwargs=dict(max_epochs=epochs),
|
|
|
|
| 122 |
)
|
| 123 |
windows = gr.Number(value=3, label="Number of Windows", precision=0)
|
| 124 |
epochs = gr.Number(value=10, label="Number of Epochs", precision=0)
|
| 125 |
+
distribution = gr.Radio(
|
| 126 |
+
choices=["StudentT", "Negative Binomial", "Normal"],
|
| 127 |
+
value="StudentT",
|
| 128 |
+
label="Distribution",
|
| 129 |
+
)
|
| 130 |
|
| 131 |
with gr.Row(label="Dataset"):
|
| 132 |
upload_btn = gr.UploadButton(label="Upload")
|
|
|
|
| 141 |
)
|
| 142 |
train_btn.click(
|
| 143 |
fn=train_and_forecast,
|
| 144 |
+
inputs=[upload_btn, prediction_length, windows, epochs, distribution],
|
| 145 |
outputs=[plot, json],
|
| 146 |
)
|
| 147 |
|