Spaces:
Runtime error
Runtime error
bug fix
Browse files
app.py
CHANGED
|
@@ -85,21 +85,29 @@ def generate_plot(df, dates, preds):
|
|
| 85 |
# Adjust spacing between subplots
|
| 86 |
plt.tight_layout()
|
| 87 |
|
| 88 |
-
for ax in axs.flat:
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
st.pyplot(fig)
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
with open('data/parameters.pkl', 'rb') as f:
|
| 98 |
parameters = pickle.load(f)
|
| 99 |
-
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
| 100 |
-
|
| 101 |
df = pd.read_pickle('data/test_data.pkl')
|
| 102 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
# Start App
|
| 105 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|
|
|
|
| 85 |
# Adjust spacing between subplots
|
| 86 |
plt.tight_layout()
|
| 87 |
|
| 88 |
+
#for ax in axs.flat:
|
| 89 |
+
# ax.set_xlim(df['Date'].min(), df['Date'].max())
|
| 90 |
+
# ax.set_ylim(df['sales'].min(), df['sales'].max())
|
| 91 |
|
| 92 |
st.pyplot(fig)
|
|
|
|
| 93 |
|
| 94 |
+
@st.cache_data
|
| 95 |
+
def load_data():
|
| 96 |
with open('data/parameters.pkl', 'rb') as f:
|
| 97 |
parameters = pickle.load(f)
|
|
|
|
|
|
|
| 98 |
df = pd.read_pickle('data/test_data.pkl')
|
| 99 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
| 100 |
+
return parameters, df
|
| 101 |
+
|
| 102 |
+
@st.cache_resource
|
| 103 |
+
def init_model():
|
| 104 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
| 105 |
+
return model
|
| 106 |
+
|
| 107 |
+
def main():
|
| 108 |
+
## Initiate Data
|
| 109 |
+
parameters, df = load_data()
|
| 110 |
+
model = init_model()
|
| 111 |
|
| 112 |
# Start App
|
| 113 |
st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
|