Spaces:
Runtime error
Runtime error
change conversion
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ from pytorch_forecasting import (
|
|
| 19 |
from PIL import Image
|
| 20 |
|
| 21 |
## Functions
|
| 22 |
-
def raw_preds_to_df(raw,
|
| 23 |
"""
|
| 24 |
raw is output of model.predict with return_index=True
|
| 25 |
quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
|
|
@@ -31,15 +31,13 @@ def raw_preds_to_df(raw, idx_offset, quantiles = None):
|
|
| 31 |
dec_len = raw.output.prediction.shape[1]
|
| 32 |
n_quantiles = preds.shape[-1]
|
| 33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
| 34 |
-
preds_df = preds_df.assign(
|
| 35 |
-
preds_df = preds_df.assign(
|
| 36 |
-
preds_df = preds_df.assign(
|
| 37 |
if quantiles is not None:
|
| 38 |
-
preds_df['
|
| 39 |
|
| 40 |
-
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['
|
| 41 |
-
preds_df['Date'] = pd.to_datetime(idx_offset)
|
| 42 |
-
preds_df['Date'] = preds_df['Date'] + preds_df['pred_idx'].apply(pd.DateOffset)
|
| 43 |
return preds_df
|
| 44 |
|
| 45 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
@@ -57,10 +55,9 @@ def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
| 57 |
df = TimeSeriesDataSet.from_parameters(_parameters, df)
|
| 58 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
| 59 |
|
| 60 |
-
def predict(_model, _dataloader
|
| 61 |
-
out =
|
| 62 |
-
|
| 63 |
-
preds = raw_preds_to_df(raw = out, idx_offset = first_date)
|
| 64 |
return preds[["pred_idx", "Group", "pred"]]
|
| 65 |
|
| 66 |
def adjust_data_for_plot(df, preds):
|
|
@@ -157,7 +154,7 @@ def main():
|
|
| 157 |
rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
|
| 158 |
|
| 159 |
dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
|
| 160 |
-
preds = predict(model, dataloader
|
| 161 |
|
| 162 |
data_plot = adjust_data_for_plot(df.copy(), preds)
|
| 163 |
fig, _ = generate_plot(data_plot)
|
|
|
|
| 19 |
from PIL import Image
|
| 20 |
|
| 21 |
## Functions
|
| 22 |
+
def raw_preds_to_df(raw, quantiles = None):
|
| 23 |
"""
|
| 24 |
raw is output of model.predict with return_index=True
|
| 25 |
quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
|
|
|
|
| 31 |
dec_len = raw.output.prediction.shape[1]
|
| 32 |
n_quantiles = preds.shape[-1]
|
| 33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
| 34 |
+
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
| 35 |
+
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
| 36 |
+
preds_df = preds_df.assign(pred=preds.flatten().cpu().numpy())
|
| 37 |
if quantiles is not None:
|
| 38 |
+
preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})
|
| 39 |
|
| 40 |
+
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
|
|
|
|
|
|
| 41 |
return preds_df
|
| 42 |
|
| 43 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
|
|
|
| 55 |
df = TimeSeriesDataSet.from_parameters(_parameters, df)
|
| 56 |
return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
|
| 57 |
|
| 58 |
+
def predict(_model, _dataloader):
|
| 59 |
+
out = model.predict(_dataloader, mode="raw", return_x=True,return_index=True, trainer_kwargs=dict(accelerator="cpu"))
|
| 60 |
+
preds = raw_preds_to_df(out)
|
|
|
|
| 61 |
return preds[["pred_idx", "Group", "pred"]]
|
| 62 |
|
| 63 |
def adjust_data_for_plot(df, preds):
|
|
|
|
| 154 |
rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key = "rain")
|
| 155 |
|
| 156 |
dataloader = prepare_dataset(parameters, df.copy(), st.session_state.rain, st.session_state.temperature, st.session_state.date, RAIN_MAPPING)
|
| 157 |
+
preds = predict(model, dataloader)
|
| 158 |
|
| 159 |
data_plot = adjust_data_for_plot(df.copy(), preds)
|
| 160 |
fig, _ = generate_plot(data_plot)
|