# FC / main.py
# jahanking's picture
# Add model with Git LFS
# 94e03df
# Run locally:  python -m uvicorn main:app --reload
# GET  http://127.0.0.1:8000/predict-sales/     -> sales summary for all products
# POST http://localhost:8000/forecast-product/  -> forecast for one product, JSON body:
#      { "product_id": "P0002" }
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.arima.model import ARIMA
import matplotlib.pyplot as plt
import uvicorn
import warnings
warnings.filterwarnings("ignore")
# === LSTM Model ===
class LSTMResidual(nn.Module):
    """Two-layer LSTM regressor mapping a univariate sequence to one value.

    The network reads a batch-first sequence with one feature per step, keeps
    only the LSTM output of the final time step, applies dropout and projects
    the result to a single scalar per sample.
    """

    def __init__(self):
        super(LSTMResidual, self).__init__()
        hidden_units = 64
        # The attribute names below are the prefixes of the state-dict keys,
        # so they must stay exactly as they are.
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_units,
            num_layers=2,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(in_features=hidden_units, out_features=1)

    def forward(self, x):
        """Return one prediction per sample for input shaped (batch, seq, 1)."""
        sequence_output, _state = self.lstm(x)
        final_step = sequence_output[:, -1, :]
        return self.fc(self.dropout(final_step))
# === Load Assets ===
# Inventory history. Column names and product IDs are whitespace-trimmed so
# lookups by ID match exactly; dates are parsed as day-month-year.
df = pd.read_csv("Name_Updated_Retail_Inventory.csv")
df.columns = df.columns.str.strip()
df['Product ID'] = df['Product ID'].astype(str).str.strip()
df['Date'] = pd.to_datetime(df['Date'], format="%d-%m-%Y")

# SECURITY: weights_only=False unpickles arbitrary Python objects (required
# here because the checkpoint also carries a scaler object next to the LSTM
# weights). Only load checkpoint files that come from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU machine load on
# a CPU-only host; the endpoints build their input tensors on the CPU.
checkpoint = torch.load(
    "hybrid_inventory_model.pt",
    map_location="cpu",
    weights_only=False,
)
model = LSTMResidual()
model.load_state_dict(checkpoint['lstm_state_dict'])
model.eval()  # inference mode: disables dropout
scaler = checkpoint['scaler']
# === FastAPI App ===
# The title's leading emoji (U+1F4E6, package) is written as a Unicode escape:
# the literal had been stored as mojibake, and an escape cannot be corrupted
# by a wrong file encoding.
app = FastAPI(title="\U0001F4E6 Inventory Forecast API")
# === Request Schema ===
# Request body accepted by the single-product forecast endpoint.
class ProductRequest(BaseModel):
    # Identifier of the product to forecast, e.g. "P0002".
    product_id: str
# === 1. Single Product Forecast Endpoint ===
def _hybrid_inventory_forecast(ts, steps=30, seq_len=7):
    """Forecast ``steps`` future values of a daily inventory series.

    An ARIMA(2, 1, 1) model produces the base forecast. Its in-sample
    residuals are scaled with the module-level ``scaler`` and the last
    ``seq_len`` of them seed the module-level LSTM ``model``, which is rolled
    forward one step at a time. The unscaled LSTM output is added to the
    ARIMA forecast.

    Args:
        ts: pandas Series of inventory levels indexed by date.
        steps: Number of future periods to forecast.
        seq_len: Length of the residual window fed to the LSTM.

    Returns:
        1-D numpy array with ``steps`` forecast values.

    Raises:
        ValueError: If fewer than 10 residuals are available.
    """
    arima_result = ARIMA(ts, order=(2, 1, 1)).fit()
    arima_forecast = arima_result.forecast(steps=steps)

    residuals = arima_result.resid.dropna()
    if len(residuals) < 10:
        raise ValueError("❌ Not enough residuals for LSTM.")
    residuals_scaled = scaler.transform(residuals.values.reshape(-1, 1))

    # Autoregressive roll-out: each prediction is appended to the window and
    # the oldest entry is dropped before predicting the next step.
    window = residuals_scaled[-seq_len:].copy()
    predicted = []
    with torch.no_grad():
        for _ in range(steps):
            batch = torch.tensor(
                window.reshape(1, seq_len, 1), dtype=torch.float32
            )
            step_pred = model(batch).item()
            predicted.append(step_pred)
            window = np.append(window[1:], [[step_pred]], axis=0)

    residual_forecast = scaler.inverse_transform(
        np.array(predicted).reshape(-1, 1)
    ).flatten()
    return np.asarray(arima_forecast).ravel() + residual_forecast


def _save_forecast_plot(ts, forecast, name, path="forecast.png"):
    """Write a PNG showing the last 60 observations and the forecast."""
    # Local import. The Figure API is used directly instead of pyplot, as
    # Matplotlib recommends for web servers: no global figure registry is
    # involved, so nothing is left open if plotting fails part-way.
    from matplotlib.figure import Figure

    history = ts.iloc[-60:]
    future_dates = pd.date_range(
        start=ts.index[-1] + pd.Timedelta(days=1), periods=len(forecast)
    )

    fig = Figure(figsize=(12, 5))
    ax = fig.subplots()
    ax.plot(history.index, history.values,
            label="Historical Inventory", color='skyblue')
    ax.plot(future_dates, forecast, label="Forecast", color='orange')
    ax.set_title(f"Inventory Forecast for {name}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Inventory Level")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)


@app.post("/forecast-product/")
def forecast_product(request: ProductRequest):
    """Forecast the next 30 days of inventory for a single product.

    Args:
        request: Body carrying the ``product_id`` to forecast.

    Returns:
        Dict with the product ID and name, the min/max of the 30-day
        forecast, the estimated sales (sum of day-over-day inventory drops)
        and a URL for the saved plot.

    Raises:
        HTTPException: 404 for an unknown product ID, 400 when the product
            has fewer than 50 distinct dates, 500 when forecasting or
            plotting fails.
    """
    product_id = request.product_id.strip()

    # One boolean mask serves both the existence check and the row selection.
    is_product = df['Product ID'] == product_id
    if not is_product.any():
        raise HTTPException(status_code=404, detail=f"❌ Invalid Product ID '{product_id}'.")

    product_data = df[is_product].sort_values('Date')
    if product_data['Date'].nunique() < 50:
        raise HTTPException(status_code=400, detail=f"❌ Not enough unique days of data for Product ID '{product_id}'.")

    try:
        # Daily series: average duplicate dates, then forward-fill the gaps
        # introduced by asfreq. (.ffill() replaces the deprecated
        # fillna(method='ffill').)
        ts = product_data.groupby('Date')['Inventory Level'].mean().asfreq('D').ffill()
        if ts.empty or len(ts) < 10:
            raise ValueError("❌ Inventory time series too short for ARIMA.")

        final_forecast_30 = _hybrid_inventory_forecast(ts, steps=30, seq_len=7)
        # Sum of consecutive decreases; telescopes to first minus last value.
        estimated_sales = (final_forecast_30[:-1] - final_forecast_30[1:]).sum()
        name = product_data['Product Name'].iloc[0]

        # NOTE: every request overwrites the same file, so concurrent callers
        # may download each other's plot.
        _save_forecast_plot(ts, final_forecast_30, name)

        return {
            "Product ID": product_id,
            "Product Name": name,
            "Inventory Range (Next 30 Days)": {
                # float(...) yields plain Python numbers for the JSON encoder.
                "Min": round(float(final_forecast_30.min()), 2),
                "Max": round(float(final_forecast_30.max()), 2)
            },
            "Estimated Sales (Next 30 Days)": round(float(estimated_sales), 2),
            # NOTE: hard-coded host; only valid when the API is reached via
            # localhost:8000.
            "Download Forecast Plot": "http://localhost:8000/forecast-plot/"
        }
    except Exception as e:
        # Endpoint boundary: report any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=f"❌ Forecast failed: {str(e)}") from e
# === 2. All Products Summary Endpoint ===
@app.get("/predict-sales/")
def predict_sales():
    """Predict 30-day total sales for every product with enough history.

    For each product the daily inventory series is forecast with
    ARIMA(2, 1, 1) plus an LSTM correction on the ARIMA residuals; predicted
    sales are the sum of day-over-day drops in the forecast inventory.
    Products whose daily series has fewer than 50 points, or whose forecast
    raises, are skipped.

    Returns:
        Dict with ``top_5``, ``bottom_5`` and ``all_predictions`` record
        lists, sorted by predicted sales in descending order. All three are
        empty when no product could be forecast.
    """
    seq_len = 7
    predicted_sales_summary = []

    # A single groupby pass replaces re-filtering the whole frame for every
    # product; sort=False keeps products in first-appearance order.
    for product_id, rows in df.groupby('Product ID', sort=False):
        product_data = rows.sort_values('Date')
        # .ffill() replaces the deprecated fillna(method='ffill').
        ts = product_data.groupby('Date')['Inventory Level'].mean().asfreq('D').ffill()
        if len(ts) < 50:
            continue
        try:
            arima_result = ARIMA(ts, order=(2, 1, 1)).fit()
            arima_forecast = arima_result.forecast(steps=30)
            residuals = arima_result.resid.dropna()
            residuals_scaled = scaler.transform(residuals.values.reshape(-1, 1))

            # LSTM roll-out: feed each prediction back into the input window.
            next_30 = []
            last_input = residuals_scaled[-seq_len:].copy()
            with torch.no_grad():
                for _ in range(30):
                    input_tensor = torch.tensor(
                        last_input.reshape(1, seq_len, 1), dtype=torch.float32
                    )
                    pred = model(input_tensor).item()
                    next_30.append(pred)
                    last_input = np.append(last_input[1:], [[pred]], axis=0)

            predicted_residuals_unscaled = scaler.inverse_transform(
                np.array(next_30).reshape(-1, 1)
            ).flatten()
            final_forecast_30 = arima_forecast.values.flatten() + predicted_residuals_unscaled
            total_sales = (final_forecast_30[:-1] - final_forecast_30[1:]).sum()
            name = product_data['Product Name'].iloc[0]
            predicted_sales_summary.append({
                'Product ID': product_id,
                'Product Name': name,
                'Predicted Total Sales (Units)': round(float(total_sales), 2)
            })
        except Exception as e:
            # Best effort: one failing product must not abort the summary.
            print(f"❌ Skipping {product_id} due to error: {e}")
            continue

    # With no rows the frame has no columns and sort_values would raise
    # KeyError, so answer with empty lists instead.
    if not predicted_sales_summary:
        return {"top_5": [], "bottom_5": [], "all_predictions": []}

    result_df = pd.DataFrame(predicted_sales_summary).sort_values(
        by='Predicted Total Sales (Units)', ascending=False
    )
    return {
        "top_5": result_df.head(5).to_dict(orient="records"),
        "bottom_5": result_df.tail(5).to_dict(orient="records"),
        "all_predictions": result_df.to_dict(orient="records")
    }
# === 3. Serve Plot Image ===
@app.get("/forecast-plot/")
def get_forecast_plot():
    """Return the saved forecast plot as a PNG download.

    Returns:
        FileResponse serving ``forecast.png``.

    Raises:
        HTTPException: 404 when ``forecast.png`` does not exist, e.g. before
            any forecast has been generated.
    """
    import os  # local import keeps this block self-contained

    plot_path = "forecast.png"
    # Without this check a missing file surfaces as an internal server error
    # instead of a clear "not found".
    if not os.path.isfile(plot_path):
        raise HTTPException(
            status_code=404,
            detail="❌ Forecast plot not found. Generate one via POST /forecast-product/ first.",
        )
    return FileResponse(plot_path, media_type="image/png", filename="forecast.png")
# === Server Runner ===
# Start the server only when this file is executed directly as a script.
if __name__ == "__main__":
    # "main:app" is passed as an import string (module `main`, attribute
    # `app`); the server binds to 0.0.0.0 on port 8000 with reload enabled.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)