Spaces:
Runtime error
Runtime error
bug fix
Browse files- README.md +0 -1
- app.py +30 -26
- requirements.txt +8 -6
README.md
CHANGED
|
@@ -5,7 +5,6 @@ colorFrom: blue
|
|
| 5 |
colorTo: pink
|
| 6 |
sdk: streamlit
|
| 7 |
sdk_version: 1.21.0
|
| 8 |
-
python_version: 3.10.0
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: mit
|
|
|
|
| 5 |
colorTo: pink
|
| 6 |
sdk: streamlit
|
| 7 |
sdk_version: 1.21.0
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
app.py
CHANGED
|
@@ -65,6 +65,32 @@ def predict(model, dataloader):
|
|
| 65 |
#preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
|
| 66 |
#preds.drop(columns=['time_idx_y'],inplace=True)
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
def main():
|
| 69 |
## Initiate Data
|
| 70 |
with open('data/parameters.pkl', 'rb') as f:
|
|
@@ -72,7 +98,7 @@ def main():
|
|
| 72 |
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
| 73 |
|
| 74 |
df = pd.read_pickle('data/test_data.pkl')
|
| 75 |
-
df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
|
| 76 |
|
| 77 |
rain_mapping = {
|
| 78 |
"Yes" : 1,
|
|
@@ -110,31 +136,9 @@ def main():
|
|
| 110 |
|
| 111 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
|
| 117 |
-
axs[0, 0].set_title('Article Group 1')
|
| 118 |
-
|
| 119 |
-
axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
|
| 120 |
-
axs[0, 1].set_title('Article Group 2')
|
| 121 |
-
|
| 122 |
-
axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
|
| 123 |
-
axs[1, 0].set_title('Article Group 3')
|
| 124 |
-
|
| 125 |
-
axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
|
| 126 |
-
axs[1, 1].set_title('Article Group 4')
|
| 127 |
-
|
| 128 |
-
# Adjust spacing between subplots
|
| 129 |
-
plt.tight_layout()
|
| 130 |
-
|
| 131 |
-
for ax in axs.flat:
|
| 132 |
-
ax.set_xlim(df['Date'].min(), df['Date'].max())
|
| 133 |
-
ax.set_ylim(df['sales'].min(), df['sales'].max())
|
| 134 |
-
|
| 135 |
-
st.pyplot(fig)
|
| 136 |
-
|
| 137 |
-
st.button("Forecast Sales", type="primary")
|
| 138 |
|
| 139 |
if __name__ == '__main__':
|
| 140 |
main()
|
|
|
|
| 65 |
#preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
|
| 66 |
#preds.drop(columns=['time_idx_y'],inplace=True)
|
| 67 |
|
| 68 |
+
def generate_plot(df, predictions):
|
| 69 |
+
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
|
| 70 |
+
|
| 71 |
+
# Plot scatter plots for each group
|
| 72 |
+
axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
|
| 73 |
+
axs[0, 0].set_title('Article Group 1')
|
| 74 |
+
|
| 75 |
+
axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
|
| 76 |
+
axs[0, 1].set_title('Article Group 2')
|
| 77 |
+
|
| 78 |
+
axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
|
| 79 |
+
axs[1, 0].set_title('Article Group 3')
|
| 80 |
+
|
| 81 |
+
axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
|
| 82 |
+
axs[1, 1].set_title('Article Group 4')
|
| 83 |
+
|
| 84 |
+
# Adjust spacing between subplots
|
| 85 |
+
plt.tight_layout()
|
| 86 |
+
|
| 87 |
+
for ax in axs.flat:
|
| 88 |
+
ax.set_xlim(df['Date'].min(), df['Date'].max())
|
| 89 |
+
ax.set_ylim(df['sales'].min(), df['sales'].max())
|
| 90 |
+
|
| 91 |
+
st.pyplot(fig)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
def main():
|
| 95 |
## Initiate Data
|
| 96 |
with open('data/parameters.pkl', 'rb') as f:
|
|
|
|
| 98 |
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
| 99 |
|
| 100 |
df = pd.read_pickle('data/test_data.pkl')
|
| 101 |
+
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
| 102 |
|
| 103 |
rain_mapping = {
|
| 104 |
"Yes" : 1,
|
|
|
|
| 136 |
|
| 137 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
| 138 |
|
| 139 |
+
if st.button("Forecast Sales", type="primary"):
|
| 140 |
+
converted_data = prepare_dataset(parameters, df, rain, temperature, datepicker)
|
| 141 |
+
generate_plot(converted_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
if __name__ == '__main__':
|
| 144 |
main()
|
requirements.txt
CHANGED
|
@@ -22,7 +22,7 @@ decorator==5.1.1
|
|
| 22 |
executing==1.2.0
|
| 23 |
fonttools==4.38.0
|
| 24 |
frozenlist==1.3.3
|
| 25 |
-
fsspec
|
| 26 |
future==0.18.3
|
| 27 |
google-auth==2.16.0
|
| 28 |
google-auth-oauthlib==0.4.6
|
|
@@ -49,7 +49,7 @@ multidict==6.0.4
|
|
| 49 |
nest-asyncio==1.5.6
|
| 50 |
numpy==1.23.5
|
| 51 |
oauthlib==3.2.2
|
| 52 |
-
optuna
|
| 53 |
packaging==23.0
|
| 54 |
pandas==1.5.2
|
| 55 |
parso==0.8.3
|
|
@@ -72,8 +72,8 @@ Pygments==2.14.0
|
|
| 72 |
pyparsing==3.0.9
|
| 73 |
pyperclip==1.8.2
|
| 74 |
python-dateutil==2.8.2
|
| 75 |
-
pytorch-forecasting
|
| 76 |
-
pytorch-lightning
|
| 77 |
pytz==2022.7.1
|
| 78 |
PyYAML==6.0
|
| 79 |
pyzmq==25.0.0
|
|
@@ -81,7 +81,7 @@ requests==2.28.2
|
|
| 81 |
requests-futures==1.0.0
|
| 82 |
requests-oauthlib==1.3.1
|
| 83 |
rsa==4.9
|
| 84 |
-
scikit-learn==1.
|
| 85 |
scipy==1.10.0
|
| 86 |
six==1.16.0
|
| 87 |
SQLAlchemy==1.4.46
|
|
@@ -93,8 +93,10 @@ tensorboard-data-server==0.6.1
|
|
| 93 |
tensorboard-plugin-wit==1.8.1
|
| 94 |
tensorboardX==2.5.1
|
| 95 |
threadpoolctl==3.1.0
|
| 96 |
-
torch
|
|
|
|
| 97 |
torchmetrics==0.11.0
|
|
|
|
| 98 |
tornado==6.2
|
| 99 |
tqdm==4.64.1
|
| 100 |
traitlets==5.9.0
|
|
|
|
| 22 |
executing==1.2.0
|
| 23 |
fonttools==4.38.0
|
| 24 |
frozenlist==1.3.3
|
| 25 |
+
fsspec==2022.11.0
|
| 26 |
future==0.18.3
|
| 27 |
google-auth==2.16.0
|
| 28 |
google-auth-oauthlib==0.4.6
|
|
|
|
| 49 |
nest-asyncio==1.5.6
|
| 50 |
numpy==1.23.5
|
| 51 |
oauthlib==3.2.2
|
| 52 |
+
optuna==2.10.1
|
| 53 |
packaging==23.0
|
| 54 |
pandas==1.5.2
|
| 55 |
parso==0.8.3
|
|
|
|
| 72 |
pyparsing==3.0.9
|
| 73 |
pyperclip==1.8.2
|
| 74 |
python-dateutil==2.8.2
|
| 75 |
+
pytorch-forecasting==0.10.3
|
| 76 |
+
pytorch-lightning==1.9.0
|
| 77 |
pytz==2022.7.1
|
| 78 |
PyYAML==6.0
|
| 79 |
pyzmq==25.0.0
|
|
|
|
| 81 |
requests-futures==1.0.0
|
| 82 |
requests-oauthlib==1.3.1
|
| 83 |
rsa==4.9
|
| 84 |
+
scikit-learn==1.1.3
|
| 85 |
scipy==1.10.0
|
| 86 |
six==1.16.0
|
| 87 |
SQLAlchemy==1.4.46
|
|
|
|
| 93 |
tensorboard-plugin-wit==1.8.1
|
| 94 |
tensorboardX==2.5.1
|
| 95 |
threadpoolctl==3.1.0
|
| 96 |
+
torch==1.10.2
|
| 97 |
+
torchaudio==0.10.2
|
| 98 |
torchmetrics==0.11.0
|
| 99 |
+
torchvision==0.11.3
|
| 100 |
tornado==6.2
|
| 101 |
tqdm==4.64.1
|
| 102 |
traitlets==5.9.0
|