Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,9 +2,10 @@ import streamlit as st
|
|
| 2 |
import pandas as pd
|
| 3 |
from prophet import Prophet
|
| 4 |
from datetime import datetime, timedelta
|
|
|
|
| 5 |
|
| 6 |
# Prepare data for Prophet
|
| 7 |
-
def
|
| 8 |
end_date = datetime.now()
|
| 9 |
start_date = end_date - timedelta(days=len(usage_series) - 1)
|
| 10 |
dates = [start_date + timedelta(days=i) for i in range(len(usage_series))]
|
|
@@ -12,28 +13,28 @@ def prepare_data(usage_series):
|
|
| 12 |
'ds': dates,
|
| 13 |
'y': usage_series
|
| 14 |
})
|
| 15 |
-
prophet_df['cap'] =
|
| 16 |
prophet_df['floor'] = 0
|
| 17 |
return prophet_df
|
| 18 |
|
| 19 |
-
# Train or update Prophet model
|
| 20 |
-
def
|
| 21 |
-
print("Training with
|
| 22 |
-
|
| 23 |
model = Prophet(
|
| 24 |
yearly_seasonality=False,
|
| 25 |
weekly_seasonality=True,
|
| 26 |
daily_seasonality=True,
|
| 27 |
-
changepoint_prior_scale=0.
|
| 28 |
growth='logistic'
|
| 29 |
)
|
| 30 |
-
model.fit(
|
| 31 |
return model
|
| 32 |
|
| 33 |
# Function to make forecasts
|
| 34 |
def make_forecast(model, periods):
|
| 35 |
future = model.make_future_dataframe(periods=periods)
|
| 36 |
-
future['cap'] =
|
| 37 |
future['floor'] = 0
|
| 38 |
forecast = model.predict(future)
|
| 39 |
daily_forecasts = forecast['yhat'].tail(periods).tolist()
|
|
@@ -41,7 +42,7 @@ def make_forecast(model, periods):
|
|
| 41 |
return round(sum(max(0, y) for y in daily_forecasts)) # Clip negative values
|
| 42 |
|
| 43 |
# Function to validate input
|
| 44 |
-
def
|
| 45 |
try:
|
| 46 |
usage_list = [float(x) for x in usage_str.split(',')]
|
| 47 |
if len(usage_list) != 60:
|
|
@@ -65,14 +66,14 @@ def main():
|
|
| 65 |
current_stock = st.number_input("Current Stock", min_value=0, value=0)
|
| 66 |
|
| 67 |
if st.button("Generate Forecast"):
|
| 68 |
-
usage_list, error =
|
| 69 |
if error:
|
| 70 |
st.error(error)
|
| 71 |
return
|
| 72 |
st.write("Debug: Input usage series:", usage_list) # Debug
|
| 73 |
|
| 74 |
try:
|
| 75 |
-
model =
|
| 76 |
except Exception as e:
|
| 77 |
st.error(f"Error training model: {str(e)}")
|
| 78 |
return
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
from prophet import Prophet
|
| 4 |
from datetime import datetime, timedelta
|
| 5 |
+
import numpy as np
|
| 6 |
|
| 7 |
# Prepare data for Prophet
|
| 8 |
+
def prepare_prophet_data(usage_series):
|
| 9 |
end_date = datetime.now()
|
| 10 |
start_date = end_date - timedelta(days=len(usage_series) - 1)
|
| 11 |
dates = [start_date + timedelta(days=i) for i in range(len(usage_series))]
|
|
|
|
| 13 |
'ds': dates,
|
| 14 |
'y': usage_series
|
| 15 |
})
|
| 16 |
+
prophet_df['cap'] = 60 # Max observed usage
|
| 17 |
prophet_df['floor'] = 0
|
| 18 |
return prophet_df
|
| 19 |
|
| 20 |
+
# Train or update Prophet model with user-provided usage series
|
| 21 |
+
def train_model_with_usage(usage_series):
|
| 22 |
+
print("Training with changepoint_prior_scale=0.002, usage:", usage_series) # Debug
|
| 23 |
+
prophet_df = prepare_prophet_data(usage_series)
|
| 24 |
model = Prophet(
|
| 25 |
yearly_seasonality=False,
|
| 26 |
weekly_seasonality=True,
|
| 27 |
daily_seasonality=True,
|
| 28 |
+
changepoint_prior_scale=0.002,
|
| 29 |
growth='logistic'
|
| 30 |
)
|
| 31 |
+
model.fit(prophet_df)
|
| 32 |
return model
|
| 33 |
|
| 34 |
# Function to make forecasts
|
| 35 |
def make_forecast(model, periods):
|
| 36 |
future = model.make_future_dataframe(periods=periods)
|
| 37 |
+
future['cap'] = 60
|
| 38 |
future['floor'] = 0
|
| 39 |
forecast = model.predict(future)
|
| 40 |
daily_forecasts = forecast['yhat'].tail(periods).tolist()
|
|
|
|
| 42 |
return round(sum(max(0, y) for y in daily_forecasts)) # Clip negative values
|
| 43 |
|
| 44 |
# Function to validate input
|
| 45 |
+
def validate_usage_series(usage_str):
|
| 46 |
try:
|
| 47 |
usage_list = [float(x) for x in usage_str.split(',')]
|
| 48 |
if len(usage_list) != 60:
|
|
|
|
| 66 |
current_stock = st.number_input("Current Stock", min_value=0, value=0)
|
| 67 |
|
| 68 |
if st.button("Generate Forecast"):
|
| 69 |
+
usage_list, error = validate_usage_series(usage_series)
|
| 70 |
if error:
|
| 71 |
st.error(error)
|
| 72 |
return
|
| 73 |
st.write("Debug: Input usage series:", usage_list) # Debug
|
| 74 |
|
| 75 |
try:
|
| 76 |
+
model = train_model_with_usage(usage_list)
|
| 77 |
except Exception as e:
|
| 78 |
st.error(f"Error training model: {str(e)}")
|
| 79 |
return
|