virustechhacks's picture
Upload app.py with huggingface_hub
69e1d66 verified
import os
import json
import logging
from datetime import datetime, timedelta
from contextlib import asynccontextmanager
from typing import Optional
import numpy as np
import pandas as pd
import torch
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Global state ────────────────────────────────────────────
# Set by load_model() during app startup. While MODEL is None, forecasts come
# from the model-free fallback series.
MODEL = None
# Parsed artifacts/tft_config.json; read for "overall_mape_pct".
CONFIG = None
DATASET_PARAMS = None  # saved from checkpoint for TimeSeriesDataSet.from_parameters()
# An empty key disables live weather enrichment (fetch_weather returns {}).
OPENWEATHER_KEY = os.getenv("OPENWEATHER_API_KEY", "")
# Day 0 of the integer time_idx fed to the TFT: time_idx = days since this date.
# NOTE(review): presumably must match the origin used at training time — confirm.
ORIGIN_DATE = datetime(2002, 9, 17)
# Node names accepted by the API (requests for any other name are rejected).
NODES = [
    "Ahmedabad Cold Storage",
    "Chennai Port",
    "Delhi DC",
    "Mumbai Hub",
    "Pune Warehouse",
]
# Node -> city name used for the OpenWeatherMap lookup.
NODE_CITIES = {
    "Ahmedabad Cold Storage": "Ahmedabad",
    "Chennai Port": "Chennai",
    "Delhi DC": "Delhi",
    "Mumbai Hub": "Mumbai",
    "Pune Warehouse": "Pune",
}
# demand baseline (0-1 normalized) and unit scale per node.
# "demand" is the base level of the synthetic TFT history; "scale" converts the
# model's normalized quantile output into units.
NODE_BASELINE = {
    "Mumbai Hub": {"demand": 0.45, "scale": 500},
    "Pune Warehouse": {"demand": 0.35, "scale": 350},
    "Ahmedabad Cold Storage": {"demand": 0.28, "scale": 300},
    "Delhi DC": {"demand": 0.62, "scale": 750},
    "Chennai Port": {"demand": 0.38, "scale": 400},
}
# units below which we flag a reorder alert: a day whose lower-bound forecast is
# below this value is marked reorder_alert (unknown nodes default to 250).
SAFETY_STOCK = {
    "Mumbai Hub": 280,
    "Pune Warehouse": 200,
    "Ahmedabad Cold Storage": 240,
    "Delhi DC": 500,
    "Chennai Port": 300,
}
# Per-node baseline for the model-free fallback forecast. _fallback_series reads
# only "predicted"; its band is derived from that value, not from lower/upper.
FALLBACK_UNITS = {
    "Mumbai Hub": {"predicted": 225, "lower": 196, "upper": 259},
    "Chennai Port": {"predicted": 264, "lower": 230, "upper": 304},
    "Delhi DC": {"predicted": 466, "lower": 405, "upper": 536},
    "Ahmedabad Cold Storage": {"predicted": 221, "lower": 192, "upper": 254},
    "Pune Warehouse": {"predicted": 176, "lower": 153, "upper": 202},
}
# ── India festival calendar ──────────────────────────────────
# Diwali (month, day) per year. festival_effect boosts demand within ±7 days of
# it and gives no Diwali boost for years missing here; the synthetic TFT history
# in _make_history_df assumes Nov 1 for missing years instead.
DIWALI = {
    2023: (11, 12), 2024: (11, 1), 2025: (10, 20),
    2026: (11, 8), 2027: (10, 29), 2028: (10, 17),
}
# (month, day, window_days, multiplier, label)
# Fixed month/day entries that apply in EVERY year. Movable festivals (Holi,
# Eid al-Fitr) are listed at two approximate dates ("alt year"), and both dates
# apply every year. festival_effect returns the first entry whose ±window_days
# range contains the date, so list order decides which label wins on overlaps.
FESTIVALS = [
    (1, 14, 2, 1.15, "Pongal / Makar Sankranti"),
    (1, 26, 1, 1.05, "Republic Day"),
    (3, 14, 3, 1.25, "Holi"),
    (3, 25, 3, 1.25, "Holi"),  # alt year
    (3, 30, 3, 1.30, "Eid al-Fitr"),  # approximate
    (4, 10, 3, 1.30, "Eid al-Fitr"),  # alt year
    (8, 15, 1, 1.10, "Independence Day"),
    (9, 5, 5, 1.20, "Onam"),
    (10, 2, 9, 1.20, "Navratri"),
    (10, 10, 5, 1.20, "Durga Puja"),
    (12, 25, 2, 1.10, "Christmas"),
    (12, 31, 2, 1.10, "New Year Eve"),
    (1, 1, 1, 1.08, "New Year"),
]
def festival_effect(dt: datetime) -> tuple[float, Optional[str]]:
    """Return the (demand multiplier, festival label) in effect on *dt*.

    Diwali takes precedence: within 7 days of the DIWALI date for dt's year
    the result is (1.40, "Diwali"). Otherwise the first FESTIVALS entry whose
    ±window_days range contains *dt* is returned. Festival dates are built in
    dt's own year only, so no window reaches across the Dec/Jan boundary.
    Returns (1.0, None) when nothing applies.
    """
    diwali = DIWALI.get(dt.year)
    if diwali is not None:
        if abs((dt - datetime(dt.year, *diwali)).days) <= 7:
            return 1.40, "Diwali"

    for month, day, window, mult, label in FESTIVALS:
        try:
            anchor = datetime(dt.year, month, day)
        except ValueError:
            # Date that doesn't exist this year (e.g. Feb 29): skip it.
            continue
        if abs((dt - anchor).days) <= window:
            return mult, label

    return 1.0, None
def seasonal_effect(dt: datetime) -> tuple[float, Optional[str]]:
    """Return the (demand multiplier, season label) for the month of *dt*.

    Jun-Sep monsoon 0.85; Nov post-monsoon peak 1.35; Oct and Dec harvest
    season 1.10; Jan-Feb winter slowdown 0.92; Mar-May neutral (1.0, None).
    """
    m = dt.month
    if m in (6, 7, 8, 9):
        return 0.85, "monsoon"
    if m == 11:
        return 1.35, "post-monsoon peak"
    if m in (10, 12):
        return 1.10, "harvest season"
    if m <= 2:
        return 0.92, "winter slowdown"
    return 1.0, None
# ── Weather enrichment ────────────────────────────────────────
# city -> (fetch time in seconds, parsed weather dict). Only successful
# fetches are stored; entries older than WEATHER_TTL are refetched.
_weather_cache: dict[str, tuple[float, dict]] = {}
WEATHER_TTL = 300  # seconds
async def fetch_weather(city: str) -> dict:
    """Return current weather for *city* (queried as "<city>,IN") from OpenWeatherMap.

    The result dict has temp_c (°C), rain_mm, condition, description and
    humidity. Successful results are cached per city for WEATHER_TTL seconds.
    Weather is an optional enrichment, so this never raises: it returns {}
    when no API key is configured, on a non-200 answer, or on any
    network/parse error (the latter two are logged).
    """
    import time  # local: monotonic clock so TTL checks ignore wall-clock jumps

    now = time.monotonic()
    entry = _weather_cache.get(city)
    if entry is not None:
        ts, cached = entry
        if now - ts < WEATHER_TTL:
            return cached
    if not OPENWEATHER_KEY:
        return {}
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            r = await client.get(
                "https://api.openweathermap.org/data/2.5/weather",
                params={
                    "q": f"{city},IN",
                    "appid": OPENWEATHER_KEY,
                    "units": "metric",
                },
            )
        if r.status_code != 200:
            logger.warning(f"Weather API returned HTTP {r.status_code} for {city}")
            return {}
        raw = r.json()
        rain = raw.get("rain") or {}
        data = {
            "temp_c": raw["main"]["temp"],
            # Current-weather responses report the last hour as rain.1h;
            # rain.3h only appears in older payloads. Reading only "3h" made
            # rain_mm almost always 0.
            "rain_mm": rain.get("1h", rain.get("3h", 0.0)),
            "condition": raw["weather"][0]["main"],
            "description": raw["weather"][0]["description"],
            "humidity": raw["main"]["humidity"],
        }
    except Exception as e:  # best-effort: any failure just disables enrichment
        logger.warning(f"Weather fetch failed for {city}: {e}")
        return {}
    _weather_cache[city] = (now, data)
    return data
def weather_effect(weather: dict) -> tuple[float, Optional[str]]:
    """Translate a fetch_weather() dict into a (demand multiplier, reason) pair.

    Rules are checked in order and the first match wins:
    rain > 15 mm -> 1.18; thunderstorm -> 1.18; rain > 5 mm -> 1.08;
    temp > 42°C -> 1.12; temp < 5°C -> 1.08; fog below 15°C -> 0.93.
    An empty dict, or no matching rule, gives (1.0, None). Missing keys
    default to 25°C, 0 mm rain and no condition.
    """
    if not weather:
        return 1.0, None
    temp = weather.get("temp_c", 25)
    rain = weather.get("rain_mm", 0.0)
    cond = weather.get("condition", "")
    if rain > 15:
        return 1.18, f"heavy rain ({rain:.0f} mm) — stockpiling spike"
    if cond == "Thunderstorm":
        # Same multiplier as heavy rain, but a storm with little measured rain
        # must not be reported as "heavy rain (0 mm)".
        return 1.18, f"thunderstorm ({rain:.0f} mm rain) — stockpiling spike"
    if rain > 5:
        return 1.08, f"moderate rain ({rain:.0f} mm)"
    if temp > 42:
        return 1.12, f"extreme heat ({temp:.0f}°C) — cooling goods spike"
    if temp < 5:
        return 1.08, f"cold wave ({temp:.0f}°C) — heating goods spike"
    if cond == "Fog" and temp < 15:
        return 0.93, f"dense fog ({temp:.0f}°C) — logistics delay expected"
    return 1.0, None
# ── Model loading ─────────────────────────────────────────────
def load_model() -> bool:
    """Load the TFT checkpoint and its config into the module globals.

    Reads artifacts/tft_config.json into CONFIG and artifacts/tft_final.ckpt
    into MODEL, and keeps the checkpoint's dataset_parameters in
    DATASET_PARAMS for TimeSeriesDataSet.from_parameters().

    Returns True on success. On any failure the error is logged with its
    traceback and False is returned. MODEL is only assigned after the
    weights have loaded, so a half-initialised (randomly weighted) network is
    never published to the request handlers.
    """
    global MODEL, CONFIG, DATASET_PARAMS
    try:
        from pytorch_forecasting import TemporalFusionTransformer

        ckpt_path = "artifacts/tft_final.ckpt"
        config_path = "artifacts/tft_config.json"
        if not os.path.exists(ckpt_path):
            logger.error(f"Checkpoint not found: {ckpt_path}")
            return False
        with open(config_path, encoding="utf-8") as f:
            CONFIG = json.load(f)
        # SECURITY: weights_only=False unpickles arbitrary objects. Acceptable
        # only because this is a local, trusted artifact — never load
        # user-supplied files this way.
        raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        hp = raw["hyper_parameters"]
        # Save dataset_parameters before stripping — needed for TimeSeriesDataSet.from_parameters()
        DATASET_PARAMS = hp.get("dataset_parameters") or raw.get("dataset_parameters")
        # Only strip hparams that are truly incompatible with this version.
        # Keys not in TFT's explicit params but NOT in this strip-list pass through
        # **kwargs → BaseModel (e.g. output_transformer, weight_decay, optimizer_params).
        STRIP_KEYS = {
            "monotone_constraints",  # renamed to monotone_constaints (typo) in 1.1.1
            "mask_bias",  # added in newer version, unknown to 1.1.1
            "reduce_on_plateau_reduction",
            "reduce_on_plateau_min_lr",
            "dataset_parameters",  # stored separately in DATASET_PARAMS
        }
        for key in STRIP_KEYS:
            if key in hp:
                logger.info(f"Stripping incompatible hparam: {key}")
                hp.pop(key)
        # Build and load into a local first: if load_state_dict raises
        # (e.g. key mismatch), MODEL must stay None so requests use the fallback
        # instead of an untrained network labelled source="model".
        model = TemporalFusionTransformer(**hp)
        model.load_state_dict(raw["state_dict"])
        model.eval()
        MODEL = model
        logger.info(f"Model loaded — MAPE {CONFIG.get('overall_mape_pct')}%")
        return True
    except Exception as e:  # startup boundary: degrade to fallback mode
        logger.exception(f"Model load failed: {e}")
        return False
# ── TFT synthetic-history inference ──────────────────────────
def _make_history_df(node_name: str, end_date: datetime, horizon: int) -> pd.DataFrame:
    """Generate 60-day encoder context + *horizon* decoder rows of synthetic history for TFT.

    One row per day, ending on *end_date* (inclusive), with columns
    node_name, time_idx (days since ORIGIN_DATE), day_of_week, month_str,
    is_monsoon, is_diwali_week, is_harvest_season, demand_index and
    price_volatility. NOTE(review): presumably these match the training
    schema in DATASET_PARAMS — confirm.

    demand_index is the node's NODE_BASELINE demand scaled by monsoon (x0.85),
    harvest (x1.10) and Diwali-week (x1.35) factors plus N(0, 0.03) noise,
    clipped to [0.05, 0.95]. price_volatility is 0.20 + N(0, 0.04), clipped to
    [0.02, 0.80].

    The noise RNG is seeded from a CRC32 of the node name, so a given
    (node_name, end_date, horizon) always yields the same frame in every
    process.
    """
    import zlib  # local: stable checksum for the per-node RNG seed

    total = 60 + horizon
    base_d = NODE_BASELINE.get(node_name, {"demand": 0.4})["demand"]
    # Built-in hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so seeding from it gave every restart/worker a
    # different "fixed" history. crc32 is process-independent.
    rng = np.random.default_rng(seed=zlib.crc32(node_name.encode("utf-8")))
    rows = []
    for i in range(total):
        dt = end_date - timedelta(days=total - 1 - i)
        month = dt.month
        year = dt.year
        # Years missing from DIWALI fall back to Nov 1.
        dm, dd = DIWALI.get(year, (11, 1))
        diwali_dt = datetime(year, dm, dd)
        is_diwali = 1.0 if abs((dt - diwali_dt).days) <= 7 else 0.0
        is_monsoon = 1.0 if 6 <= month <= 9 else 0.0
        is_harvest = 1.0 if 10 <= month <= 12 else 0.0
        seasonal = 1.0
        if is_monsoon:
            seasonal *= 0.85
        if is_harvest:
            seasonal *= 1.10
        if is_diwali:
            seasonal *= 1.35
        # Draw order (demand noise, then volatility noise) fixes the RNG stream.
        demand = float(np.clip(base_d * seasonal + rng.normal(0, 0.03), 0.05, 0.95))
        volatility = float(np.clip(0.20 + rng.normal(0, 0.04), 0.02, 0.80))
        rows.append({
            "node_name": node_name,
            "time_idx": int((dt - ORIGIN_DATE).days),
            "day_of_week": str(dt.weekday()),
            "month_str": str(month),
            "is_monsoon": is_monsoon,
            "is_diwali_week": is_diwali,
            "is_harvest_season": is_harvest,
            "demand_index": demand,
            "price_volatility": volatility,
        })
    return pd.DataFrame(rows)
# node_name -> (cache_key, median, lower, upper) unit lists from run_tft_inference.
# cache_key encodes node, calendar day and horizon, so an entry is reused only
# for the same day and horizon; one entry is kept per node.
_inference_cache: dict[str, tuple[str, list[int], list[int], list[int]]] = {}
def _sanitize_dataset_params(dp: dict) -> dict:
    """Return a shallow copy of *dp* with None/missing collection fields made empty.

    The result is passed to TimeSeriesDataSet.from_parameters(). Each
    collection-typed argument that is None or absent is replaced with an
    empty container of the type TimeSeriesDataSet documents for it: lists for
    the variable-name lists, and dicts for mappings (including
    variable_groups, a name -> columns mapping). *dp* itself is not modified.
    """
    list_fields = (
        "static_categoricals", "static_reals",
        "time_varying_known_categoricals", "time_varying_known_reals",
        "time_varying_unknown_categoricals", "time_varying_unknown_reals",
    )
    # variable_groups is Dict[str, List[str]], not a list; an empty list here
    # would give the dataset the wrong container type.
    dict_fields = (
        "variable_groups", "lags", "constant_fill_strategy",
        "categorical_encoders", "scalers",
    )
    patched = dict(dp)
    for field in list_fields:
        if patched.get(field) is None:
            patched[field] = []
    for field in dict_fields:
        if patched.get(field) is None:
            patched[field] = {}
    return patched
def _tft_quantile_columns(n_quantiles: int) -> tuple[int, int, int]:
    """Return the output columns holding the ~q10, q50 and ~q90 forecasts of MODEL."""
    quantiles = getattr(getattr(MODEL, "loss", None), "quantiles", None)
    quantiles = list(quantiles) if quantiles is not None else []
    if len(quantiles) != n_quantiles:
        # Unknown layout: assume the 3-column (q10, q50, q90) head.
        return 0, 1, 2

    # With more than three quantiles (the TFT QuantileLoss default has seven),
    # columns 0/1/2 would be the wrong quantiles, so pick the nearest ones.
    def nearest(target: float) -> int:
        return min(range(n_quantiles), key=lambda i: abs(quantiles[i] - target))

    return nearest(0.1), nearest(0.5), nearest(0.9)


def run_tft_inference(node_name: str, horizon: int) -> tuple[list[int], list[int], list[int]]:
    """Return (median, lower_q10, upper_q90) unit lists via a real TFT forward pass.

    A synthetic frame from _make_history_df is wrapped with
    TimeSeriesDataSet.from_parameters(predict=True) using the checkpoint's
    DATASET_PARAMS. The model's quantile output (normalized demand) is
    multiplied by NODE_BASELINE[node]["scale"] and floored at 0. Each list has
    min(horizon, 30) entries, and entry d is the forecast for today + d + 1
    days. Requires MODEL to be loaded.

    Results are memoised in _inference_cache per node for the same calendar
    day and horizon. Exceptions from pytorch_forecasting/torch propagate.
    """
    from pytorch_forecasting import TimeSeriesDataSet

    today = datetime.now()
    cache_key = f"{node_name}:{today.strftime('%Y-%m-%d')}:{horizon}"
    cached = _inference_cache.get(node_name)
    if cached is not None and cached[0] == cache_key:
        _, med, lo, hi = cached
        return med, lo, hi

    h = min(horizon, 30)
    params = _sanitize_dataset_params(DATASET_PARAMS or {})

    # Predict mode uses the LAST max_prediction_length rows as the decoder
    # window. End the synthetic frame that many days after today so the
    # decoder spans today+1 .. today+pred_len — the dates the forecast points
    # are labelled with — instead of the days ending today.
    pred_len = max(h, int(params.get("max_prediction_length") or h))
    df = _make_history_df(node_name, today + timedelta(days=pred_len), pred_len)

    dataset = TimeSeriesDataSet.from_parameters(
        params, df,
        predict=True,
        stop_randomization=True,
    )
    loader = dataset.to_dataloader(train=False, batch_size=1, num_workers=0)
    with torch.no_grad():
        preds = MODEL.predict(loader, mode="quantiles")  # [1, pred_len, n_quantiles]
    q = preds[0].cpu().numpy()  # [pred_len, n_quantiles]

    i_lo, i_med, i_hi = _tft_quantile_columns(q.shape[-1])
    scale = NODE_BASELINE.get(node_name, {"scale": 400})["scale"]
    lo = [max(0, int(q[i, i_lo] * scale)) for i in range(h)]
    med = [max(0, int(q[i, i_med] * scale)) for i in range(h)]
    hi = [max(0, int(q[i, i_hi] * scale)) for i in range(h)]
    _inference_cache[node_name] = (cache_key, med, lo, hi)
    return med, lo, hi
# ── Fallback (seasonal + noise, no model) ─────────────────────
def _fallback_series(node_name: str, horizon: int) -> tuple[list[int], list[int], list[int]]:
    """Model-free baseline forecast: (median, lower, upper) unit lists of length *horizon*.

    Each day is the node's FALLBACK_UNITS["predicted"] value (200 for unknown
    nodes) plus Gaussian noise with sd = 3% of that value. A fixed seed of 42
    makes the output deterministic. The band is -13% / +15% around each
    median value.

    Seasonal and festival multipliers are deliberately NOT applied here.
    _build_forecasts multiplies whatever series it receives by the seasonal,
    festival and weather factors, so applying them here as well counted them
    twice (e.g. Diwali in November was boosted ~3.6x instead of ~1.9x).
    """
    fb = FALLBACK_UNITS.get(node_name, {"predicted": 200, "lower": 174, "upper": 230})
    base = fb["predicted"]
    rng = np.random.default_rng(seed=42)
    med, lo, hi = [], [], []
    for _ in range(horizon):
        noise = int(rng.normal(0, base * 0.03))
        p = max(0, int(base) + noise)
        med.append(p)
        lo.append(max(0, int(p * 0.87)))
        hi.append(int(p * 1.15))
    return med, lo, hi
# ── FastAPI app ────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the TFT model before the app serves requests; log on shutdown."""
    logger.info("Loading TFT model...")
    if load_model():
        logger.info("Model ready")
    else:
        logger.info("Fallback mode active")
    yield
    logger.info("Shutting down")
# The ASGI application; lifespan() runs load_model() once at startup.
# NOTE(review): the MAPE in the description is hard-coded and can drift from
# CONFIG["overall_mape_pct"].
app = FastAPI(
    title = "SmartChain Forecasting API",
    description = "TFT demand forecasting for Indian supply chain nodes. MAPE: 1.79%",
    version = "2.0.0",
    lifespan = lifespan,
)
# NOTE(review): wildcard CORS (any origin, method and header). Fine for a
# public demo; restrict allow_origins before serving non-public data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Schemas ────────────────────────────────────────────────────
class ForecastRequest(BaseModel):
    """Request body for POST /forecast-demand and POST /forecast-enriched."""
    node_name: str  # must be one of NODES (checked by _validate, else HTTP 400)
    horizon_days: int = 30  # days ahead; 1-30 (checked by _validate, else HTTP 400)
class ForecastPoint(BaseModel):
    """One forecast day in a /forecast-demand response."""
    date: str  # forecast day as "YYYY-MM-DD"
    node_name: str
    predicted_units: int  # median forecast in units
    lower_bound: int  # lower band in units; compared against safety stock
    upper_bound: int  # upper band in units
    confidence_pct: float  # heuristic fraction (not a percentage) decaying from 0.97 with lead time
    reorder_alert: bool  # True when lower_bound is below the node's safety stock
class ForecastResponse(BaseModel):
    """Response body of POST /forecast-demand."""
    # model_mape_pct starts with pydantic's protected "model_" prefix; opting
    # out of the namespace check avoids the warning older pydantic v2 releases emit.
    model_config = {"protected_namespaces": ()}
    node_name: str
    horizon_days: int
    forecasts: list[ForecastPoint]  # one point per day, starting tomorrow
    model_mape_pct: Optional[float]  # CONFIG["overall_mape_pct"], or 1.79 without a config
    source: str  # "model" or "fallback"
    reorder_alert: bool  # True if any day's lower bound is below safety stock
    alert_reason: Optional[str]  # describes the first alerting day, else None
class EnrichedForecastPoint(BaseModel):
    """One forecast day in a /forecast-enriched response, with its adjustment factors."""
    date: str  # forecast day as "YYYY-MM-DD"
    node_name: str
    predicted_units: int  # median forecast in units, after all factors
    lower_bound: int  # lower band in units; compared against safety stock
    upper_bound: int  # upper band in units
    confidence_pct: float  # heuristic fraction (not a percentage) decaying from 0.97 with lead time
    reorder_alert: bool  # True when lower_bound is below the node's safety stock
    seasonal_factor: float  # multiplier from seasonal_effect for this date
    festival_factor: float  # multiplier from festival_effect for this date
    weather_factor: float  # current-weather multiplier (same for every day)
    festival_name: Optional[str]  # festival in effect, if any
    weather_reason: Optional[str]  # why weather_factor != 1.0, if it is
    season_label: Optional[str]  # season name, or None for neutral months
class WeatherSnapshot(BaseModel):
    """Current weather used for the enrichment, echoed in /forecast-enriched."""
    city: str  # city queried for the node
    temp_c: Optional[float]  # degrees Celsius (requested with units=metric)
    condition: Optional[str]  # OpenWeatherMap main condition, e.g. "Rain"
    rain_mm: Optional[float]  # reported rainfall in mm
    description: Optional[str]  # OpenWeatherMap free-text description
class EnrichedForecastResponse(BaseModel):
    """Response body of POST /forecast-enriched."""
    # model_mape_pct starts with pydantic's protected "model_" prefix; opting
    # out of the namespace check avoids the warning older pydantic v2 releases emit.
    model_config = {"protected_namespaces": ()}
    node_name: str
    horizon_days: int
    forecasts: list[EnrichedForecastPoint]  # one point per day, starting tomorrow
    model_mape_pct: Optional[float]  # CONFIG["overall_mape_pct"], or 1.79 without a config
    source: str  # "model" or "fallback"
    weather: Optional[WeatherSnapshot]  # None when no weather data was available
    reorder_alert: bool  # True if any day's lower bound is below safety stock
    alert_reason: Optional[str]  # describes the first alerting day, else None
class HealthResponse(BaseModel):
    """Response body of GET /health."""
    # Same opt-out as the forecast response models: model_loaded starts with
    # pydantic's protected "model_" prefix, which makes older pydantic v2
    # releases emit a UserWarning when this class is defined.
    model_config = {"protected_namespaces": ()}
    status: str  # always "ok" when the process is serving
    model_loaded: bool  # True once load_model() has published MODEL
    dataset_params: bool  # True when checkpoint dataset_parameters are available
    overall_mape_pct: Optional[float]  # from CONFIG, or 1.79 without a config
    nodes: list[str]  # valid node names
    weather_enabled: bool  # True when an OpenWeatherMap API key is configured
# ── Shared forecast builder ────────────────────────────────────
def _build_forecasts(
    node_name: str,
    horizon: int,
    weather: dict,
    enriched: bool,
) -> tuple[list, str, bool, Optional[str]]:
    """Build *horizon* daily forecast points for *node_name*, starting tomorrow.

    The TFT model is used only when both MODEL and DATASET_PARAMS are loaded;
    any inference error, or a missing model, falls back to _fallback_series.
    Each day's median/lower/upper values are then multiplied by that day's
    seasonal and festival factors and the current weather factor.

    Returns (points, source, reorder_alert, alert_reason):
        points: EnrichedForecastPoint items if *enriched*, else ForecastPoint.
        source: "model" or "fallback".
        reorder_alert: True if any day's lower bound is below the node's
            safety stock.
        alert_reason: description of the first such day, else None.
    """
    med = lo = hi = None
    source = "fallback"
    if MODEL is not None and DATASET_PARAMS is not None:
        try:
            med, lo, hi = run_tft_inference(node_name, horizon)
            source = "model"
        except Exception as e:  # best-effort: any model failure degrades to fallback
            logger.warning(f"TFT inference failed for {node_name}: {e} — fallback")
    if med is None:
        med, lo, hi = _fallback_series(node_name, horizon)

    safety = SAFETY_STOCK.get(node_name, 250)
    w_mult, w_reason = weather_effect(weather)
    # Anchor every date to one instant so the series stays consecutive even
    # if the loop happens to straddle midnight.
    start = datetime.now()
    points = []
    reorder_flag = False
    alert_reason = None
    for d in range(horizon):
        dt = start + timedelta(days=d + 1)
        date_str = dt.strftime("%Y-%m-%d")
        s_mult, s_label = seasonal_effect(dt)
        f_mult, f_label = festival_effect(dt)
        # Apply weather + festival enrichment on top of model output
        combined = s_mult * f_mult * w_mult
        p = max(0, int(med[d] * combined))
        p_lo = max(0, int(lo[d] * combined))
        p_hi = max(0, int(hi[d] * combined))
        # Heuristic confidence: decays 0.004 per day of lead time, floored at 0.50.
        conf = round(max(0.50, 0.97 - d * 0.004), 3)
        alert = p_lo < safety
        if alert and not reorder_flag:
            reorder_flag = True
            alert_reason = (
                f"Lower bound ({p_lo} units) on {date_str} "
                f"is below safety stock ({safety} units)"
            )
        common = dict(
            date=date_str,
            node_name=node_name,
            predicted_units=p,
            lower_bound=p_lo,
            upper_bound=p_hi,
            confidence_pct=conf,
            reorder_alert=alert,
        )
        if enriched:
            points.append(EnrichedForecastPoint(
                **common,
                seasonal_factor=round(s_mult, 3),
                festival_factor=round(f_mult, 3),
                weather_factor=round(w_mult, 3),
                festival_name=f_label,
                weather_reason=w_reason,
                season_label=s_label,
            ))
        else:
            points.append(ForecastPoint(**common))
    return points, source, reorder_flag, alert_reason
# ── Routes ─────────────────────────────────────────────────────
@app.get("/")
async def root():
    """Service banner: name, version, status, docs path, MAPE and public endpoints."""
    endpoints = ["/health", "/nodes", "/forecast-demand", "/forecast-enriched"]
    return dict(
        name="SmartChain Forecasting API",
        version="2.0.0",
        status="running",
        docs="/docs",
        mape="1.79%",
        endpoints=endpoints,
    )
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report model/dataset-parameter load state, MAPE, nodes and weather availability."""
    mape = 1.79
    if CONFIG:
        mape = CONFIG.get("overall_mape_pct")
    return HealthResponse(
        status="ok",
        model_loaded=MODEL is not None,
        dataset_params=DATASET_PARAMS is not None,
        overall_mape_pct=mape,
        nodes=NODES,
        weather_enabled=bool(OPENWEATHER_KEY),
    )
@app.get("/nodes")
async def get_nodes():
    """List the valid node names together with their safety-stock thresholds."""
    payload = {"nodes": NODES}
    payload["safety_stock"] = SAFETY_STOCK
    return payload
@app.post("/forecast-demand", response_model=ForecastResponse)
async def forecast_demand(req: ForecastRequest):
    """Daily forecast for one node, without the per-day factor breakdown.

    Seasonal, festival and current-weather factors are still applied to the
    numbers. Raises HTTPException(400) for an unknown node or a horizon
    outside 1-30 days.
    """
    _validate(req)
    weather = await fetch_weather(NODE_CITIES.get(req.node_name, "Mumbai"))
    forecasts, source, alert, reason = _build_forecasts(
        req.node_name, req.horizon_days, weather, enriched=False
    )
    mape = CONFIG.get("overall_mape_pct") if CONFIG else 1.79
    return ForecastResponse(
        node_name=req.node_name,
        horizon_days=req.horizon_days,
        forecasts=forecasts,
        model_mape_pct=mape,
        source=source,
        reorder_alert=alert,
        alert_reason=reason,
    )
@app.post("/forecast-enriched", response_model=EnrichedForecastResponse)
async def forecast_enriched(req: ForecastRequest):
    """Daily forecast for one node with per-day seasonal/festival/weather factors.

    Also echoes the weather snapshot used, or None when no weather data was
    available. Raises HTTPException(400) for an unknown node or a horizon
    outside 1-30 days.
    """
    _validate(req)
    city = NODE_CITIES.get(req.node_name, "Mumbai")
    weather = await fetch_weather(city)
    forecasts, source, alert, reason = _build_forecasts(
        req.node_name, req.horizon_days, weather, enriched=True
    )
    snapshot = (
        WeatherSnapshot(
            city=city,
            temp_c=weather.get("temp_c"),
            condition=weather.get("condition"),
            rain_mm=weather.get("rain_mm"),
            description=weather.get("description"),
        )
        if weather
        else None
    )
    mape = CONFIG.get("overall_mape_pct") if CONFIG else 1.79
    return EnrichedForecastResponse(
        node_name=req.node_name,
        horizon_days=req.horizon_days,
        forecasts=forecasts,
        model_mape_pct=mape,
        source=source,
        weather=snapshot,
        reorder_alert=alert,
        alert_reason=reason,
    )
def _validate(req: ForecastRequest):
    """Raise HTTPException(400) for an unknown node or a horizon outside 1-30 days.

    The node is checked first, so an unknown node is reported even when the
    horizon is also out of range.
    """
    problem = None
    if req.node_name not in NODES:
        problem = f"Unknown node '{req.node_name}'. Valid: {NODES}"
    elif req.horizon_days < 1 or req.horizon_days > 30:
        problem = "horizon_days must be 1–30"
    if problem is not None:
        raise HTTPException(status_code=400, detail=problem)