test / app.py
AjaykumarPilla's picture
Update app.py
bc41d2b verified
raw
history blame
6.46 kB
import streamlit as st
import pandas as pd
import pickle
from prophet import Prophet
from datetime import datetime, timedelta
import numpy as np
import plotly.graph_objects as go
# Load the trained models (optional, for initialization or fallback)
@st.cache_resource
def load_model():
    """Load the optional pre-trained models from ``model.pkl``.

    Wrapped in ``st.cache_resource`` so Streamlit does not re-read the file on
    every rerun of the script.

    Returns:
        The object stored in ``model.pkl``, or ``None`` when the file is
        missing, unreadable, or not a valid pickle. These models are optional,
        so a bad file must not take the whole page down.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. Only deploy a model.pkl you produced yourself; never load one
    # supplied by users.
    try:
        with open('model.pkl', 'rb') as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        # OSError covers the original FileNotFoundError plus permission and
        # other read errors; EOFError/UnpicklingError mean a truncated or
        # corrupt file.
        return None
# Prepare data for Prophet
def prepare_prophet_data(usage_series, end_date=None):
    """Build a Prophet-ready DataFrame from a daily usage series.

    The series is treated as one observation per calendar day. The last value
    falls on ``end_date``, which defaults to today.

    Args:
        usage_series: Sequence of daily usage values, oldest first.
        end_date: Date of the final observation. Anything accepted by
            ``pd.Timestamp`` works; ``None`` means the current local date.

    Returns:
        DataFrame with columns ``ds`` (midnight timestamps, one per
        consecutive day) and ``y`` (the usage values), the column names
        ``Prophet.fit`` expects.
    """
    if end_date is None:
        end_date = datetime.now()
    # Normalise to midnight so ``ds`` holds whole calendar days instead of
    # carrying whatever clock time the app happened to run at.
    end = pd.Timestamp(end_date).normalize()
    dates = pd.date_range(end=end, periods=len(usage_series), freq='D')
    # list() avoids index alignment if a pandas Series is passed in.
    return pd.DataFrame({'ds': dates, 'y': list(usage_series)})
# Train a fresh Prophet model on the user-provided usage series
def train_model_with_usage(usage_series):
    """Fit a new Prophet model on a daily usage series.

    Args:
        usage_series: Sequence of daily usage values, oldest first.

    Returns:
        The fitted ``Prophet`` model.
    """
    prophet_df = prepare_prophet_data(usage_series)
    model = Prophet(
        # The app only supplies ~60 days of history: too short for a yearly cycle.
        yearly_seasonality=False,
        weekly_seasonality=True,
        # One observation per day carries no intra-day signal. Forcing daily
        # seasonality on would only add parameters the data cannot identify.
        daily_seasonality=False,
        # Below Prophet's default of 0.05: a stiffer, less reactive trend.
        changepoint_prior_scale=0.01,
    )
    model.fit(prophet_df)
    return model
# Function to make forecasts
def make_forecast(model, periods):
    """Return the total forecast usage over the next ``periods`` days.

    Args:
        model: A fitted Prophet model, or anything exposing
            ``make_future_dataframe(periods=...)`` and ``predict(df)``.
        periods: Number of days after the training history to forecast.

    Returns:
        The summed daily forecast, rounded to the nearest whole unit, as an
        ``int``. Returns 0 when ``periods`` is not positive.
    """
    if periods <= 0:
        return 0
    future = model.make_future_dataframe(periods=periods)
    forecast = model.predict(future)
    # Prophet's additive model can dip below zero, but consumption never does.
    # Clip each day so negative days don't cancel out real demand in the total.
    daily = forecast['yhat'].tail(periods).clip(lower=0)
    return int(round(daily.sum()))
# Function to validate input
def validate_usage_series(usage_str, expected_length=60):
    """Parse and validate a comma-separated daily usage series.

    Args:
        usage_str: Text such as ``"3, 4.5, 0, ..."`` entered by the user.
        expected_length: Exact number of values required (60 by default).

    Returns:
        ``(values, None)`` on success, where ``values`` is a list of floats.
        Otherwise ``(None, message)``, where the message describes the first
        problem found.
    """
    try:
        usage_list = [float(x) for x in usage_str.split(',')]
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: input is not a string. ValueError: a token
        # is not a number (including empty tokens from stray commas).
        return None, f"Invalid usage series format. Please enter {expected_length} comma-separated numbers."
    if len(usage_list) != expected_length:
        return None, f"Usage series must contain exactly {expected_length} values."
    # float() accepts "nan" and "inf". Neither is a meaningful usage count, and
    # nan would also slip past the `< 0` check below.
    if not np.isfinite(usage_list).all():
        return None, "Usage values must be finite numbers."
    if any(x < 0 for x in usage_list):
        return None, "Usage values must be non-negative."
    return usage_list, None
# Main Streamlit app

# Forecast horizons in days, in display order.
_FORECAST_HORIZONS = (7, 14, 30)


def _render_forecast_results(forecasts):
    """Write the total forecast units for each horizon."""
    st.header("Forecast Results")
    for days, units in forecasts.items():
        st.write(f"**{days}-Day Forecast**: {units} units")


def _render_threshold_alerts(current_stock, forecasts):
    """Emit a warning for every horizon whose forecast exceeds current stock."""
    st.header("Threshold Alerts")
    for days, units in forecasts.items():
        if current_stock < units:
            st.warning(f"Alert: Current stock ({current_stock}) is below {days}-day forecast ({units}). 🚩")


def _render_order_suggestions(current_stock, forecasts):
    """Write how many extra units to order to cover each horizon."""
    st.header("Order Suggestions")
    for days, units in forecasts.items():
        order = max(0, round(units - current_stock))  # Round to nearest integer
        st.write(f"**For {days} Days**: Order {order} additional units.")


def _plot_forecast(forecasts):
    """Draw a line chart of forecast units per horizon."""
    st.header("Forecast Visualization")
    fig_forecast = go.Figure()
    fig_forecast.add_trace(go.Scatter(
        x=[f"{days} Days" for days in forecasts],
        y=list(forecasts.values()),
        mode='lines+markers',
        name='Forecasted Units',
        line=dict(color='blue'),
        marker=dict(size=10),
    ))
    fig_forecast.update_layout(
        title='Consumable Usage Forecast',
        xaxis_title='Time Period',
        yaxis_title='Units',
        template='plotly_white',
    )
    st.plotly_chart(fig_forecast)


def _plot_threshold_alerts(current_stock, forecasts):
    """Draw a bar chart comparing current stock with each forecast."""
    st.header("Threshold Alerts Visualization")
    categories = ['Current Stock'] + [f"{days}-Day Forecast" for days in forecasts]
    units = [current_stock] + list(forecasts.values())
    # The stock bar never carries an alert. A forecast bar is flagged when
    # current stock cannot cover it.
    alerts = [False] + [current_stock < value for value in forecasts.values()]
    fig_alerts = go.Figure()
    fig_alerts.add_trace(go.Bar(
        x=categories,
        y=units,
        marker_color=['green'] + ['red' if alert else 'blue' for alert in alerts[1:]],
        text=["🚩" if alert else "" for alert in alerts],
        textposition='auto',
    ))
    fig_alerts.update_layout(
        title='Stock vs Forecast with Alerts (🚩 indicates low stock)',
        xaxis_title='Category',
        yaxis_title='Units',
        template='plotly_white',
    )
    st.plotly_chart(fig_alerts)


def main():
    """Render the SmartLab consumables forecasting page.

    Shows the input form. When "Generate Forecast" is pressed, it validates
    the usage series and trains a Prophet model on it. It then displays
    7/14/30-day forecasts, stock alerts, order suggestions and two charts.
    Validation, training and forecasting errors are reported with
    ``st.error`` and stop the run.
    """
    st.title("SmartLab Consumables Forecast")
    # NOTE(review): the pre-trained models and the selected consumable type
    # are never used below; the forecast always comes from a model trained on
    # the entered series. Confirm whether per-consumable models were intended.
    models = load_model()
    # Input form
    st.header("Input Parameters")
    consumable_type = st.selectbox("Consumable Type", ['Filters', 'Reagents', 'Vials'])
    usage_series = st.text_input("Last 60 Days Usage (comma-separated)", "")
    current_stock = st.number_input("Current Stock", min_value=0, value=0)
    if not st.button("Generate Forecast"):
        return

    usage_list, error = validate_usage_series(usage_series)
    if error:
        st.error(error)
        return

    try:
        model = train_model_with_usage(usage_list)
    except Exception as e:
        st.error(f"Error training model: {str(e)}")
        return

    # Same error handling as training, so a predict failure is reported
    # instead of surfacing as an uncaught exception.
    try:
        forecasts = {days: make_forecast(model, days) for days in _FORECAST_HORIZONS}
    except Exception as e:
        st.error(f"Error generating forecast: {e}")
        return

    _render_forecast_results(forecasts)
    _render_threshold_alerts(current_stock, forecasts)
    _render_order_suggestions(current_stock, forecasts)
    _plot_forecast(forecasts)
    _plot_threshold_alerts(current_stock, forecasts)
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()