| import os |
| import json |
| import logging |
| from datetime import datetime, timedelta |
| from contextlib import asynccontextmanager |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import httpx |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Mutable model state. load_model() rebinds these; the forecast code checks
# them to choose between TFT inference and the heuristic fallback.
MODEL = CONFIG = DATASET_PARAMS = None

# An empty key disables live weather lookups (fetch_weather returns {}).
OPENWEATHER_KEY = os.environ.get("OPENWEATHER_API_KEY", "")

# Day 0 of the integer ``time_idx`` axis fed to the TFT.
# NOTE(review): presumably the same origin used at training time — confirm before changing.
ORIGIN_DATE = datetime(year=2002, month=9, day=17)
|
|
# Supply-chain nodes served by the API. _validate rejects any other node_name
# with HTTP 400; the list is also echoed by /health and /nodes.
NODES = [
    "Ahmedabad Cold Storage",
    "Chennai Port",
    "Delhi DC",
    "Mumbai Hub",
    "Pune Warehouse",
]


# City used for each node's OpenWeatherMap lookup (queried as "<city>,IN").
NODE_CITIES = {
    "Ahmedabad Cold Storage": "Ahmedabad",
    "Chennai Port": "Chennai",
    "Delhi DC": "Delhi",
    "Mumbai Hub": "Mumbai",
    "Pune Warehouse": "Pune",
}
|
|
| |
# Per-node parameters for TFT inference:
#   demand - baseline demand_index level around which _make_history_df draws
#            the synthetic input history
#   scale  - multiplier that converts the model's quantile outputs into units
# NOTE(review): presumably these mirror the per-node scale of the training data — confirm.
NODE_BASELINE = {
    "Mumbai Hub": {"demand": 0.45, "scale": 500},
    "Pune Warehouse": {"demand": 0.35, "scale": 350},
    "Ahmedabad Cold Storage": {"demand": 0.28, "scale": 300},
    "Delhi DC": {"demand": 0.62, "scale": 750},
    "Chennai Port": {"demand": 0.38, "scale": 400},
}


# Safety-stock threshold in units per node (default 250 for unknown nodes).
# A forecast day raises reorder_alert when its lower bound is below this value.
# NOTE(review): every node's FALLBACK_UNITS "predicted" and "lower" values are
# below its threshold here, so fallback forecasts will usually raise alerts —
# confirm this is intended.
SAFETY_STOCK = {
    "Mumbai Hub": 280,
    "Pune Warehouse": 200,
    "Ahmedabad Cold Storage": 240,
    "Delhi DC": 500,
    "Chennai Port": 300,
}


# Baseline daily units used when the TFT model is unavailable or inference
# fails. _fallback_series reads only "predicted"; "lower"/"upper" are not used
# there (its bounds are derived as fixed ratios of the median).
FALLBACK_UNITS = {
    "Mumbai Hub": {"predicted": 225, "lower": 196, "upper": 259},
    "Chennai Port": {"predicted": 264, "lower": 230, "upper": 304},
    "Delhi DC": {"predicted": 466, "lower": 405, "upper": 536},
    "Ahmedabad Cold Storage": {"predicted": 221, "lower": 192, "upper": 254},
    "Pune Warehouse": {"predicted": 176, "lower": 153, "upper": 202},
}
|
|
|
|
| |
# Diwali date (month, day) per year.
#  - festival_effect applies a 1.40x boost within 7 days either side, and only
#    for years listed here.
#  - _make_history_df uses it for the is_diwali_week feature, falling back to
#    Nov 1 for unlisted years.
# Extend this table for years after 2028.
DIWALI = {
    2023: (11, 12), 2024: (11, 1), 2025: (10, 20),
    2026: (11, 8), 2027: (10, 29), 2028: (10, 17),
}


# Festival windows checked in order by festival_effect (first match wins).
# Each entry is (month, day, window_days, demand_multiplier, label).
# The anchor date is always built in the queried date's own year.
# NOTE(review): moving festivals appear twice with different dates (Holi on
# 3/14 and 3/25, Eid al-Fitr on 3/30 and 4/10). Both anchors are applied every
# year, so these look like single-year dates. Confirm, or key them by year as
# DIWALI is.
FESTIVALS = [
    (1, 14, 2, 1.15, "Pongal / Makar Sankranti"),
    (1, 26, 1, 1.05, "Republic Day"),
    (3, 14, 3, 1.25, "Holi"),
    (3, 25, 3, 1.25, "Holi"),
    (3, 30, 3, 1.30, "Eid al-Fitr"),
    (4, 10, 3, 1.30, "Eid al-Fitr"),
    (8, 15, 1, 1.10, "Independence Day"),
    (9, 5, 5, 1.20, "Onam"),
    (10, 2, 9, 1.20, "Navratri"),
    (10, 10, 5, 1.20, "Durga Puja"),
    (12, 25, 2, 1.10, "Christmas"),
    (12, 31, 2, 1.10, "New Year Eve"),
    (1, 1, 1, 1.08, "New Year"),
]
|
|
|
|
def festival_effect(dt: datetime) -> tuple[float, Optional[str]]:
    """Return the (demand multiplier, festival label) in effect on *dt*.

    Diwali, looked up per year in DIWALI, takes precedence with a 1.40x boost
    within 7 days either side. Otherwise the first FESTIVALS entry whose
    same-year anchor date lies within its window wins. Returns (1.0, None)
    when no festival applies.
    """
    # Anchors are midnight and timedelta.days floors, so each comparison below
    # is effectively a whole-calendar-day distance.
    diwali = DIWALI.get(dt.year)
    if diwali is not None:
        diwali_day = datetime(dt.year, *diwali)
        if abs((dt - diwali_day).days) <= 7:
            return 1.40, "Diwali"

    for month, day, window, multiplier, label in FESTIVALS:
        try:
            anchor = datetime(dt.year, month, day)
        except ValueError:
            # Anchor is not a valid date in this year; skip it.
            continue
        if abs((dt - anchor).days) <= window:
            return multiplier, label

    # Anchors are always built in dt's own year, so windows do not wrap across
    # New Year (e.g. the previous Dec 31 is not considered from Jan 1).
    return 1.0, None
|
|
|
|
def seasonal_effect(dt: datetime) -> tuple[float, Optional[str]]:
    """Return the (demand multiplier, season label) for the month of *dt*.

    Jun–Sep: monsoon dip (0.85). Nov: post-monsoon peak (1.35).
    Oct and Dec: harvest season (1.10). Jan–Feb: winter slowdown (0.92).
    Any other month: neutral (1.0, None).
    """
    m = dt.month
    if m in (6, 7, 8, 9):
        return 0.85, "monsoon"
    if m == 11:
        return 1.35, "post-monsoon peak"
    if m in (10, 12):
        return 1.10, "harvest season"
    if m in (1, 2):
        return 0.92, "winter slowdown"
    return 1.0, None
|
|
|
|
| |
# Per-city weather cache: city -> (fetch time as a POSIX timestamp, parsed payload).
_weather_cache: dict[str, tuple[float, dict]] = dict()
WEATHER_TTL = 5 * 60  # seconds a cached entry is served before refetching
|
|
|
|
async def fetch_weather(city: str) -> dict:
    """Fetch current weather for an Indian *city* from OpenWeatherMap.

    Successful results are cached per city for WEATHER_TTL seconds.

    Returns:
        A dict with temp_c, rain_mm, condition, description and humidity.
        Returns {} when no API key is configured, when the API answers with a
        non-200 status, or when the request or parsing fails. Failures are
        logged and not cached, so the next call retries.
    """
    now = datetime.now().timestamp()
    cached = _weather_cache.get(city)
    if cached is not None:
        ts, data = cached
        if now - ts < WEATHER_TTL:
            return data

    if not OPENWEATHER_KEY:
        return {}

    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            r = await client.get(
                "https://api.openweathermap.org/data/2.5/weather",
                params={
                    "q": f"{city},IN",
                    "appid": OPENWEATHER_KEY,
                    "units": "metric",
                },
            )
            if r.status_code != 200:
                # Log instead of silently returning {} — e.g. a 401 from an
                # invalid key would otherwise disable weather with no trace.
                logger.warning(
                    "Weather fetch for %s returned HTTP %s", city, r.status_code
                )
                return {}
            raw = r.json()
            data = {
                "temp_c": raw["main"]["temp"],
                # NOTE(review): only the 3-hour rain volume is read; responses
                # that carry just a 1-hour value yield 0.0 here — confirm.
                "rain_mm": raw.get("rain", {}).get("3h", 0.0),
                "condition": raw["weather"][0]["main"],
                "description": raw["weather"][0]["description"],
                "humidity": raw["main"]["humidity"],
            }
            _weather_cache[city] = (now, data)
            return data
    except Exception as e:
        logger.warning(f"Weather fetch failed for {city}: {e}")
        return {}
|
|
|
|
def weather_effect(weather: dict) -> tuple[float, Optional[str]]:
    """Map a fetch_weather() payload to a (demand multiplier, reason) pair.

    Rules are checked in order and the first match wins:
    heavy rain (> 15 mm) or a thunderstorm -> 1.18; moderate rain (> 5 mm) -> 1.08;
    extreme heat (> 42 °C) -> 1.12; cold wave (< 5 °C) -> 1.08;
    fog below 15 °C -> 0.93. No data or no match -> (1.0, None).
    Missing fields default to 25 °C, 0 mm rain and no condition.

    The reason strings are returned to API clients. They previously contained
    mis-encoded characters ("β", "Β°C"), now fixed.
    """
    if not weather:
        return 1.0, None

    temp = weather.get("temp_c", 25)
    rain = weather.get("rain_mm", 0.0)
    cond = weather.get("condition", "")

    if rain > 15 or cond == "Thunderstorm":
        return 1.18, f"heavy rain ({rain:.0f} mm) — stockpiling spike"
    elif rain > 5:
        return 1.08, f"moderate rain ({rain:.0f} mm)"
    elif temp > 42:
        return 1.12, f"extreme heat ({temp:.0f}°C) — cooling goods spike"
    elif temp < 5:
        return 1.08, f"cold wave ({temp:.0f}°C) — heating goods spike"
    elif cond == "Fog" and temp < 15:
        return 0.93, f"dense fog ({temp:.0f}°C) — logistics delay expected"
    return 1.0, None
|
|
|
|
| |
def load_model() -> bool:
    """Load the TFT checkpoint and its JSON config into the module globals.

    Reads ``artifacts/tft_final.ckpt`` and ``artifacts/tft_config.json``
    (relative to the working directory).

    On success, MODEL (in eval mode), CONFIG and DATASET_PARAMS are bound
    together and True is returned. On any failure the error is logged with its
    traceback, all three globals are left unchanged, and False is returned.
    The API then serves fallback forecasts rather than running on a
    half-initialised state.
    """
    global MODEL, CONFIG, DATASET_PARAMS

    ckpt_path = "artifacts/tft_final.ckpt"
    config_path = "artifacts/tft_config.json"

    # Hyper-parameters the installed TemporalFusionTransformer constructor does
    # not accept (presumably saved by another pytorch-forecasting version —
    # verify). dataset_parameters is read separately before being stripped.
    strip_keys = {
        "monotone_constraints",
        "mask_bias",
        "reduce_on_plateau_reduction",
        "reduce_on_plateau_min_lr",
        "dataset_parameters",
    }

    try:
        from pytorch_forecasting import TemporalFusionTransformer

        if not os.path.exists(ckpt_path):
            logger.error(f"Checkpoint not found: {ckpt_path}")
            return False

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, i.e. run code embedded in the file. Only acceptable because
        # the checkpoint is a local, trusted build artifact — never point this
        # at user-supplied files.
        raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        hp = raw["hyper_parameters"]

        # The dataset parameters may be stored inside the hyper-parameters or
        # at the top level of the checkpoint.
        dataset_params = hp.get("dataset_parameters") or raw.get("dataset_parameters")

        for key in strip_keys:
            if key in hp:
                logger.info(f"Stripping incompatible hparam: {key}")
                hp.pop(key)

        model = TemporalFusionTransformer(**hp)
        model.load_state_dict(raw["state_dict"])
        model.eval()
        mape = config.get("overall_mape_pct")
    except Exception as e:
        logger.exception(f"Model load failed: {e}")
        return False

    # Publish only once every step has succeeded.
    MODEL, CONFIG, DATASET_PARAMS = model, config, dataset_params
    logger.info(f"Model loaded — MAPE {mape}%")
    return True
|
|
|
|
| |
def _make_history_df(node_name: str, end_date: datetime, horizon: int) -> pd.DataFrame:
    """Generate 60-day encoder context + horizon decoder rows for TFT.

    Builds ``60 + horizon`` consecutive daily rows ending on *end_date*
    (inclusive). Each row carries:
      - calendar features: weekday, month, monsoon, Diwali week, harvest season;
      - a synthetic ``demand_index`` drawn around the node's NODE_BASELINE demand;
      - a synthetic ``price_volatility``.

    The random stream is seeded from the node name, so a given node always gets
    the same random draws, in every process.
    """
    import zlib  # stable, process-independent hash used for seeding

    total = 60 + horizon
    base_d = NODE_BASELINE.get(node_name, {"demand": 0.4})["demand"]
    rows = []
    # The builtin hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so it gave a different seed in every worker and after
    # every restart. crc32 is deterministic everywhere.
    rng = np.random.default_rng(seed=zlib.crc32(node_name.encode("utf-8")) % (2**31))

    for i in range(total):
        dt = end_date - timedelta(days=total - 1 - i)
        month = dt.month
        year = dt.year

        # Years missing from DIWALI fall back to Nov 1.
        dm, dd = DIWALI.get(year, (11, 1))
        diwali_dt = datetime(year, dm, dd)
        is_diwali = 1.0 if abs((dt - diwali_dt).days) <= 7 else 0.0
        is_monsoon = 1.0 if 6 <= month <= 9 else 0.0
        is_harvest = 1.0 if 10 <= month <= 12 else 0.0

        seasonal = 1.0
        if is_monsoon:
            seasonal *= 0.85
        if is_harvest:
            seasonal *= 1.10
        if is_diwali:
            seasonal *= 1.35

        demand = float(np.clip(base_d * seasonal + rng.normal(0, 0.03), 0.05, 0.95))
        volatility = float(np.clip(0.20 + rng.normal(0, 0.04), 0.02, 0.80))

        # NOTE(review): column names, string-typed categoricals and the
        # ORIGIN_DATE-based time_idx presumably must match the training
        # dataset's encoders — confirm against the checkpoint.
        rows.append({
            "node_name": node_name,
            "time_idx": int((dt - ORIGIN_DATE).days),
            "day_of_week": str(dt.weekday()),
            "month_str": str(month),
            "is_monsoon": is_monsoon,
            "is_diwali_week": is_diwali,
            "is_harvest_season": is_harvest,
            "demand_index": demand,
            "price_volatility": volatility,
        })

    return pd.DataFrame(rows)
|
|
|
|
# Latest TFT result per node: node_name -> (cache_key, median, lower, upper).
# One entry per node; a new day or a different horizon overwrites it.
_inference_cache: dict[str, tuple[str, list, list, list]] = dict()
|
|
|
|
| def _sanitize_dataset_params(dp: dict) -> dict: |
| """Ensure all collection-type fields in dataset_parameters are not None.""" |
| list_fields = [ |
| "static_categoricals", "static_reals", |
| "time_varying_known_categoricals", "time_varying_known_reals", |
| "time_varying_unknown_categoricals", "time_varying_unknown_reals", |
| "variable_groups", |
| ] |
| dict_fields = ["lags", "constant_fill_strategy", "categorical_encoders", "scalers"] |
| patched = dict(dp) |
| for field in list_fields: |
| if patched.get(field) is None: |
| patched[field] = [] |
| for field in dict_fields: |
| if patched.get(field) is None: |
| patched[field] = {} |
| return patched |
|
|
|
|
def _quantile_columns(n_cols: int) -> tuple[int, int, int]:
    """Return the column indices of the ~q10, q50 and ~q90 outputs of a quantile prediction.

    If the loaded model exposes its quantile list (``MODEL.loss.quantiles``,
    as pytorch-forecasting's QuantileLoss does) and the list matches *n_cols*,
    the quantiles nearest 0.1 / 0.5 / 0.9 are picked. Otherwise columns
    0 / 1 / 2 are used, which assumes the model was trained on exactly
    [0.1, 0.5, 0.9].
    """
    quantiles = getattr(getattr(MODEL, "loss", None), "quantiles", None)
    if quantiles is None or len(quantiles) != n_cols:
        return 0, 1, 2

    def nearest(target: float) -> int:
        return min(range(n_cols), key=lambda i: abs(float(quantiles[i]) - target))

    return nearest(0.1), nearest(0.5), nearest(0.9)


def run_tft_inference(node_name: str, horizon: int) -> tuple[list[int], list[int], list[int]]:
    """Return (median, lower_q10, upper_q90) unit lists via a TFT forward pass.

    Element ``i`` of each list is the forecast for today + i + 1 days, which
    matches the dates _build_forecasts assigns. Results are cached per node
    for the current calendar day and horizon.

    NOTE(review): the encoder history is synthetic (_make_history_df), not
    observed demand.

    Raises:
        ValueError: if the horizon exceeds the dataset's max_prediction_length,
            or the model returns fewer steps than requested.
        Exceptions from pytorch-forecasting propagate; the caller switches
        to the fallback series.
    """
    from pytorch_forecasting import TimeSeriesDataSet

    today = datetime.now()
    cache_key = f"{node_name}:{today.strftime('%Y-%m-%d')}:{horizon}"
    cached = _inference_cache.get(node_name)
    if cached is not None and cached[0] == cache_key:
        _, med, lo, hi = cached
        return med, lo, hi

    h = min(horizon, 30)
    params = _sanitize_dataset_params(DATASET_PARAMS or {})

    # In predict mode, TimeSeriesDataSet decodes the LAST max_prediction_length
    # rows of each series. The frame used to end today, so those "predicted"
    # rows were past days. End it pred_len days ahead instead, so the decoder
    # rows are exactly today+1 .. today+pred_len.
    pred_len = int(params.get("max_prediction_length") or h)
    if h > pred_len:
        raise ValueError(
            f"horizon {h} exceeds the model's max_prediction_length ({pred_len})"
        )
    df = _make_history_df(node_name, today + timedelta(days=pred_len), pred_len)

    dataset = TimeSeriesDataSet.from_parameters(
        params, df,
        predict=True,
        stop_randomization=True,
    )
    loader = dataset.to_dataloader(train=False, batch_size=1, num_workers=0)

    with torch.no_grad():
        preds = MODEL.predict(loader, mode="quantiles")

    # assumes preds is (n_series, pred_len, n_quantiles); we built one series.
    q = preds[0].cpu().numpy()[:h]
    if q.shape[0] < h:
        raise ValueError(f"model returned {q.shape[0]} steps, expected {h}")
    lo_col, med_col, hi_col = _quantile_columns(q.shape[-1])
    scale = NODE_BASELINE.get(node_name, {"scale": 400})["scale"]

    lo = [max(0, int(row[lo_col] * scale)) for row in q]
    med = [max(0, int(row[med_col] * scale)) for row in q]
    hi = [max(0, int(row[hi_col] * scale)) for row in q]

    _inference_cache[node_name] = (cache_key, med, lo, hi)
    return med, lo, hi
|
|
|
|
| |
def _fallback_series(node_name: str, horizon: int) -> tuple[list[int], list[int], list[int]]:
    """Return heuristic (median, lower, upper) unit lists for *horizon* days.

    Used when the TFT model is unavailable or inference fails. Each day's
    median is the node's FALLBACK_UNITS "predicted" baseline plus ~3% Gaussian
    noise. The noise comes from a fixed seed, so the output is deterministic.
    The lower and upper bounds are 87% and 115% of the median.

    Seasonal and festival multipliers are deliberately NOT applied here.
    _build_forecasts applies them to every series it receives; applying them
    here as well squared them on the fallback path.
    """
    fb = FALLBACK_UNITS.get(node_name, {"predicted": 200, "lower": 174, "upper": 230})
    base = fb["predicted"]
    rng = np.random.default_rng(seed=42)
    med, lo, hi = [], [], []

    for _ in range(horizon):
        noise = int(rng.normal(0, base * 0.03))
        p = max(0, int(base) + noise)
        med.append(p)
        lo.append(max(0, int(p * 0.87)))
        hi.append(int(p * 1.15))

    return med, lo, hi
|
|
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the TFT model at startup, log at shutdown.

    A failed load is not fatal; the API keeps serving via the fallback series.
    """
    logger.info("Loading TFT model...")
    if load_model():
        logger.info("Model ready")
    else:
        logger.info("Fallback mode active")
    yield
    logger.info("Shutting down")
|
|
|
|
app = FastAPI(
    lifespan=lifespan,
    title="SmartChain Forecasting API",
    version="2.0.0",
    description="TFT demand forecasting for Indian supply chain nodes. MAPE: 1.79%",
)

# NOTE(review): CORS is fully open (any origin, method and header), and
# allow_credentials is left at its default. Restrict origins if the API is
# reachable from untrusted front-ends.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
| |
class ForecastRequest(BaseModel):
    """Request body for /forecast-demand and /forecast-enriched."""

    # Must be one of NODES; _validate answers HTTP 400 otherwise.
    node_name: str
    # Days to forecast; _validate enforces 1..30 (HTTP 400 otherwise).
    horizon_days: int = 30
|
|
|
|
class ForecastPoint(BaseModel):
    """One forecast day for a node, as returned by /forecast-demand."""

    date: str  # forecast day formatted as "%Y-%m-%d"
    node_name: str
    predicted_units: int  # median units after seasonal/festival/weather multipliers
    lower_bound: int
    upper_bound: int
    # Despite the name, _build_forecasts fills this with a 0.50–0.97 fraction
    # that decays with the day offset, not a 0–100 percentage.
    confidence_pct: float
    reorder_alert: bool  # True when lower_bound is below the node's safety stock
|
|
|
|
class ForecastResponse(BaseModel):
    """Response body of /forecast-demand."""

    # pydantic setting: disable the protected "model_" namespace check so the
    # model_mape_pct field name is accepted.
    model_config = {"protected_namespaces": ()}
    node_name: str
    horizon_days: int
    forecasts: list[ForecastPoint]  # one point per day, starting tomorrow
    # CONFIG["overall_mape_pct"] when a config is loaded, else the hard-coded 1.79.
    model_mape_pct: Optional[float]
    source: str  # "model" (TFT inference) or "fallback" (heuristic series)
    reorder_alert: bool  # True if any day raised a reorder alert
    alert_reason: Optional[str]  # describes the first alerting day, else None
|
|
|
|
class EnrichedForecastPoint(BaseModel):
    """One forecast day plus the multipliers that shaped it (/forecast-enriched)."""

    date: str  # forecast day formatted as "%Y-%m-%d"
    node_name: str
    predicted_units: int
    lower_bound: int
    upper_bound: int
    # A 0.50–0.97 fraction as produced by _build_forecasts, not a 0–100 percentage.
    confidence_pct: float
    reorder_alert: bool
    seasonal_factor: float  # seasonal_effect() multiplier, rounded to 3 dp
    festival_factor: float  # festival_effect() multiplier, rounded to 3 dp
    # weather_effect() multiplier; computed once per request, so identical on every day.
    weather_factor: float
    festival_name: Optional[str]  # festival label, or None outside festival windows
    weather_reason: Optional[str]  # weather_effect() reason, or None
    season_label: Optional[str]  # seasonal_effect() label, or None
|
|
|
|
class WeatherSnapshot(BaseModel):
    """Current weather for a node's city, echoed by /forecast-enriched."""

    city: str
    temp_c: Optional[float]  # requested from OpenWeatherMap with units="metric"
    condition: Optional[str]  # OpenWeatherMap main condition, e.g. "Thunderstorm"
    rain_mm: Optional[float]  # taken from the API's "rain" -> "3h" field, 0.0 when absent
    description: Optional[str]
|
|
|
|
class EnrichedForecastResponse(BaseModel):
    """Response body of /forecast-enriched."""

    # pydantic setting: disable the protected "model_" namespace check so the
    # model_mape_pct field name is accepted.
    model_config = {"protected_namespaces": ()}
    node_name: str
    horizon_days: int
    forecasts: list[EnrichedForecastPoint]  # one point per day, starting tomorrow
    # CONFIG["overall_mape_pct"] when a config is loaded, else the hard-coded 1.79.
    model_mape_pct: Optional[float]
    source: str  # "model" (TFT inference) or "fallback" (heuristic series)
    weather: Optional[WeatherSnapshot]  # None when no weather data was fetched
    reorder_alert: bool  # True if any day raised a reorder alert
    alert_reason: Optional[str]  # describes the first alerting day, else None
|
|
|
|
class HealthResponse(BaseModel):
    """Response body of /health."""

    # pydantic v2 reserves the "model_" prefix and warns when the class is
    # created for fields such as ``model_loaded``. Disable the check, as the
    # forecast response models in this module already do.
    model_config = {"protected_namespaces": ()}
    status: str
    model_loaded: bool  # True once load_model() has set MODEL
    dataset_params: bool  # True when DATASET_PARAMS is available
    # CONFIG["overall_mape_pct"] when a config is loaded, else the hard-coded 1.79.
    overall_mape_pct: Optional[float]
    nodes: list[str]
    weather_enabled: bool  # True when an OpenWeatherMap API key is configured
|
|
|
|
| |
def _build_forecasts(
    node_name: str,
    horizon: int,
    weather: dict,
    enriched: bool,
) -> tuple[list, str, bool, Optional[str]]:
    """Build per-day forecast points for *node_name* over *horizon* days.

    Uses TFT inference when MODEL and DATASET_PARAMS are loaded. Falls back to
    _fallback_series when they are not, or when inference raises.

    Each day's median and bounds are scaled by three multipliers: seasonal,
    festival, and weather. The weather multiplier is the current reading,
    applied to every day. A day raises reorder_alert when its scaled lower
    bound is below the node's SAFETY_STOCK (default 250).

    Returns:
        (points, source, reorder_alert, alert_reason):
          - points: EnrichedForecastPoint objects if *enriched*, else ForecastPoint;
          - source: "model" or "fallback";
          - reorder_alert: True if any day alerted;
          - alert_reason: describes the first alerting day, else None.
    """
    source = "fallback"
    med = lo = hi = None

    if MODEL is not None and DATASET_PARAMS is not None:
        try:
            med, lo, hi = run_tft_inference(node_name, horizon)
            source = "model"
        except Exception as e:
            logger.warning("TFT inference failed for %s: %s — fallback", node_name, e)

    if med is None:
        med, lo, hi = _fallback_series(node_name, horizon)

    safety = SAFETY_STOCK.get(node_name, 250)
    w_mult, w_reason = weather_effect(weather)
    # Read the clock once so every day offset is measured from the same instant.
    now = datetime.now()

    points = []
    reorder_flag = False
    alert_reason = None

    for d in range(horizon):
        dt = now + timedelta(days=d + 1)
        day = dt.strftime("%Y-%m-%d")

        s_mult, s_label = seasonal_effect(dt)
        f_mult, f_label = festival_effect(dt)

        combined = s_mult * f_mult * w_mult
        p = max(0, int(med[d] * combined))
        p_lo = max(0, int(lo[d] * combined))
        p_hi = max(0, int(hi[d] * combined))

        # Confidence decays by 0.004 per day from 0.97, floored at 0.50
        # (a fraction, despite the field name confidence_pct).
        conf = round(max(0.50, 0.97 - d * 0.004), 3)
        alert = p_lo < safety

        if alert and not reorder_flag:
            reorder_flag = True
            alert_reason = (
                f"Lower bound ({p_lo} units) on {day} "
                f"is below safety stock ({safety} units)"
            )

        common = dict(
            date=day,
            node_name=node_name,
            predicted_units=p,
            lower_bound=p_lo,
            upper_bound=p_hi,
            confidence_pct=conf,
            reorder_alert=alert,
        )
        if enriched:
            points.append(EnrichedForecastPoint(
                **common,
                seasonal_factor=round(s_mult, 3),
                festival_factor=round(f_mult, 3),
                weather_factor=round(w_mult, 3),
                festival_name=f_label,
                weather_reason=w_reason,
                season_label=s_label,
            ))
        else:
            points.append(ForecastPoint(**common))

    return points, source, reorder_flag, alert_reason
|
|
|
|
| |
@app.get("/")
async def root():
    """Describe the service: name, version, status, docs path, headline MAPE and endpoints."""
    endpoints = ["/health", "/nodes", "/forecast-demand", "/forecast-enriched"]
    return dict(
        name="SmartChain Forecasting API",
        version="2.0.0",
        status="running",
        docs="/docs",
        mape="1.79%",
        endpoints=endpoints,
    )
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report model/dataset/weather availability, headline MAPE and the node list."""
    mape = CONFIG.get("overall_mape_pct") if CONFIG else 1.79
    return HealthResponse(
        status="ok",
        model_loaded=MODEL is not None,
        dataset_params=DATASET_PARAMS is not None,
        overall_mape_pct=mape,
        nodes=NODES,
        weather_enabled=bool(OPENWEATHER_KEY),
    )
|
|
|
|
@app.get("/nodes")
async def get_nodes():
    """List the supported nodes together with their safety-stock thresholds."""
    payload = {"nodes": NODES}
    payload["safety_stock"] = SAFETY_STOCK
    return payload
|
|
|
|
@app.post("/forecast-demand", response_model=ForecastResponse)
async def forecast_demand(req: ForecastRequest):
    """Daily demand forecast for one node over ``horizon_days`` days.

    Validates the request (HTTP 400 on an unknown node or bad horizon) and
    fetches the node city's current weather. Returns the forecast points with
    the data source and any reorder alert.
    """
    _validate(req)

    node, horizon = req.node_name, req.horizon_days
    weather = await fetch_weather(NODE_CITIES.get(node, "Mumbai"))
    points, source, reorder_alert, alert_reason = _build_forecasts(
        node, horizon, weather, enriched=False
    )

    return ForecastResponse(
        node_name=node,
        horizon_days=horizon,
        forecasts=points,
        model_mape_pct=CONFIG.get("overall_mape_pct") if CONFIG else 1.79,
        source=source,
        reorder_alert=reorder_alert,
        alert_reason=alert_reason,
    )
|
|
|
|
@app.post("/forecast-enriched", response_model=EnrichedForecastResponse)
async def forecast_enriched(req: ForecastRequest):
    """Forecast with a per-day factor breakdown plus the current weather snapshot.

    Same inputs and validation as /forecast-demand. Each point also carries
    the seasonal, festival and weather multipliers with their labels. The
    response includes the weather used, or None when none was fetched.
    """
    _validate(req)

    node = req.node_name
    city = NODE_CITIES.get(node, "Mumbai")
    weather = await fetch_weather(city)

    points, source, reorder_alert, alert_reason = _build_forecasts(
        node, req.horizon_days, weather, enriched=True
    )

    snapshot = (
        WeatherSnapshot(
            city=city,
            temp_c=weather.get("temp_c"),
            condition=weather.get("condition"),
            rain_mm=weather.get("rain_mm"),
            description=weather.get("description"),
        )
        if weather
        else None
    )

    return EnrichedForecastResponse(
        node_name=node,
        horizon_days=req.horizon_days,
        forecasts=points,
        model_mape_pct=CONFIG.get("overall_mape_pct") if CONFIG else 1.79,
        source=source,
        weather=snapshot,
        reorder_alert=reorder_alert,
        alert_reason=alert_reason,
    )
|
|
|
|
def _validate(req: ForecastRequest):
    """Reject unknown nodes and out-of-range horizons.

    Raises:
        HTTPException: 400 if ``req.node_name`` is not in NODES, or if
            ``req.horizon_days`` is outside 1..30.
    """
    if req.node_name not in NODES:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown node '{req.node_name}'. Valid: {NODES}",
        )
    if not (1 <= req.horizon_days <= 30):
        # The message previously contained a mis-encoded dash ("1β30").
        raise HTTPException(
            status_code=400,
            detail=f"horizon_days must be 1–30 (got {req.horizon_days})",
        )
|
|