AjaykumarPilla commited on
Commit
0107214
·
verified ·
1 Parent(s): 340a8cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -2,9 +2,10 @@ import streamlit as st
2
  import pandas as pd
3
  from prophet import Prophet
4
  from datetime import datetime, timedelta
 
5
 
6
  # Prepare data for Prophet
7
- def prepare_data(usage_series):
8
  end_date = datetime.now()
9
  start_date = end_date - timedelta(days=len(usage_series) - 1)
10
  dates = [start_date + timedelta(days=i) for i in range(len(usage_series))]
@@ -12,28 +13,28 @@ def prepare_data(usage_series):
12
  'ds': dates,
13
  'y': usage_series
14
  })
15
- prophet_df['cap'] = 30 # Lowered from 60 to 30
16
  prophet_df['floor'] = 0
17
  return prophet_df
18
 
19
- # Train or update Prophet model
20
- def train_model(usage_series):
21
- print("Training with changepoint=0.001, usage:", usage_series) # Debug
22
- prophet_data = prepare_data(usage_series)
23
  model = Prophet(
24
  yearly_seasonality=False,
25
  weekly_seasonality=True,
26
  daily_seasonality=True,
27
- changepoint_prior_scale=0.001, # Lowered from 0.002
28
  growth='logistic'
29
  )
30
- model.fit(prophet_data)
31
  return model
32
 
33
  # Function to make forecasts
34
  def make_forecast(model, periods):
35
  future = model.make_future_dataframe(periods=periods)
36
- future['cap'] = 30
37
  future['floor'] = 0
38
  forecast = model.predict(future)
39
  daily_forecasts = forecast['yhat'].tail(periods).tolist()
@@ -41,7 +42,7 @@ def make_forecast(model, periods):
41
  return round(sum(max(0, y) for y in daily_forecasts)) # Clip negative values
42
 
43
  # Function to validate input
44
- def validate_usage(usage_str):
45
  try:
46
  usage_list = [float(x) for x in usage_str.split(',')]
47
  if len(usage_list) != 60:
@@ -65,14 +66,14 @@ def main():
65
  current_stock = st.number_input("Current Stock", min_value=0, value=0)
66
 
67
  if st.button("Generate Forecast"):
68
- usage_list, error = validate_usage(usage_series)
69
  if error:
70
  st.error(error)
71
  return
72
  st.write("Debug: Input usage series:", usage_list) # Debug
73
 
74
  try:
75
- model = train_model(usage_list)
76
  except Exception as e:
77
  st.error(f"Error training model: {str(e)}")
78
  return
 
2
  import pandas as pd
3
  from prophet import Prophet
4
  from datetime import datetime, timedelta
5
+ import numpy as np
6
 
7
  # Prepare data for Prophet
8
+ def prepare_prophet_data(usage_series):
9
  end_date = datetime.now()
10
  start_date = end_date - timedelta(days=len(usage_series) - 1)
11
  dates = [start_date + timedelta(days=i) for i in range(len(usage_series))]
 
13
  'ds': dates,
14
  'y': usage_series
15
  })
16
+ prophet_df['cap'] = 60 # Max observed usage
17
  prophet_df['floor'] = 0
18
  return prophet_df
19
 
20
+ # Train or update Prophet model with user-provided usage series
21
+ def train_model_with_usage(usage_series):
22
+ print("Training with changepoint_prior_scale=0.002, usage:", usage_series) # Debug
23
+ prophet_df = prepare_prophet_data(usage_series)
24
  model = Prophet(
25
  yearly_seasonality=False,
26
  weekly_seasonality=True,
27
  daily_seasonality=True,
28
+ changepoint_prior_scale=0.002,
29
  growth='logistic'
30
  )
31
+ model.fit(prophet_df)
32
  return model
33
 
34
  # Function to make forecasts
35
  def make_forecast(model, periods):
36
  future = model.make_future_dataframe(periods=periods)
37
+ future['cap'] = 60
38
  future['floor'] = 0
39
  forecast = model.predict(future)
40
  daily_forecasts = forecast['yhat'].tail(periods).tolist()
 
42
  return round(sum(max(0, y) for y in daily_forecasts)) # Clip negative values
43
 
44
  # Function to validate input
45
+ def validate_usage_series(usage_str):
46
  try:
47
  usage_list = [float(x) for x in usage_str.split(',')]
48
  if len(usage_list) != 60:
 
66
  current_stock = st.number_input("Current Stock", min_value=0, value=0)
67
 
68
  if st.button("Generate Forecast"):
69
+ usage_list, error = validate_usage_series(usage_series)
70
  if error:
71
  st.error(error)
72
  return
73
  st.write("Debug: Input usage series:", usage_list) # Debug
74
 
75
  try:
76
+ model = train_model_with_usage(usage_list)
77
  except Exception as e:
78
  st.error(f"Error training model: {str(e)}")
79
  return