import logging
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI(title="Indian Supply Chain MARL API")

# NOTE(review): a wildcard CORS policy lets any origin call this API; tighten
# allow_origins before exposing it beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

N_AGENTS = 15
OBS_DIM = 8
N_ACTIONS = 5
HIDDEN_DIM = 64
STATE_DIM = N_AGENTS * 6  # state_feat_dim=6 per agent

# Real resources of each kind mapped onto agents: 5 trucks + 5 warehouses
# + 5 carriers = N_AGENTS. Shared by observation building and /allocate so
# agent indices and resources stay aligned.
_MAX_PER_KIND = 5

ACTION_LABELS = ["hold", "optimize", "redirect", "scale_back", "emergency"]
ACTION_DESCRIPTIONS = [
    "Maintain current allocation — no change needed",
    "Minor load rebalancing — reduce throughput by 10%",
    "Reroute shipments — major load reduction by 20%",
    "Scale back capacity — reduce utilization by 15%",
    "Emergency response — redistribute to neighboring nodes immediately",
]

# Priority reported for each action code; unknown codes map to "low".
_ACTION_PRIORITY = {1: "medium", 2: "medium", 3: "high", 4: "high"}

# Status -> availability lookups (unlisted statuses use the per-kind fallback).
_TRUCK_AVAILABILITY = {"available": 1.0, "in_transit": 0.7, "maintenance": 0.2}
_WAREHOUSE_AVAILABILITY = {
    "operational": 1.0,
    "active": 1.0,
    "maintenance": 0.3,
    "inactive": 0.1,
}


class QAgent(nn.Module):
    """Per-agent recurrent Q-network: a GRU encoder followed by a 2-layer MLP head.

    Args:
        obs_dim: Length of each observation vector.
        n_actions: Number of discrete actions (width of the Q-value output).
        hidden_dim: GRU hidden size.
    """

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, 32)
        self.fc2 = nn.Linear(32, n_actions)

    def forward(self, x, hidden):
        """Return ``(q_values, new_hidden)``.

        ``x`` is batch-first (``batch_first=True``): ``[batch, seq, obs_dim]``;
        ``hidden`` is ``[1, batch, hidden_dim]`` for this single-layer GRU.
        ``q_values`` is ``[batch, seq, n_actions]``.
        """
        out, h = self.gru(x, hidden)
        out = torch.relu(self.fc1(out))
        return self.fc2(out), h


device = torch.device("cpu")
agent = QAgent(OBS_DIM, N_ACTIONS, HIDDEN_DIM).to(device)

weights_path = "qmix_agent_weights.pth"
# Recorded at load time so /health reports whether weights were actually
# loaded, not merely whether a file exists at request time.
WEIGHTS_LOADED = False
if os.path.exists(weights_path):
    # NOTE(review): torch.load unpickles the file, so only load trusted weight
    # files; on torch >= 1.13 consider weights_only=True for a plain state dict.
    agent.load_state_dict(torch.load(weights_path, map_location=device))
    WEIGHTS_LOADED = True
else:
    logger.warning(
        "Weights file %s not found; serving an untrained (randomly initialised) QAgent",
        weights_path,
    )
agent.eval()


class InferenceRequest(BaseModel):
    """Body of POST /predict."""

    obs_n: List[List[float]]  # [n_agents, obs_dim]
    # [n_agents, 6]; accepted for API compatibility but not used by /predict.
    internal_state: Optional[List[List[float]]] = None


class AllocationRequest(BaseModel):
    """Body of POST /allocate: one dict per resource row."""

    trucks: List[dict]
    warehouses: List[dict]
    carriers: List[dict]


def _resource_obs(
    capacity_ratio: float,
    availability: float,
    disruption: float,
    utilization: float,
    neighbor_avg: float = 0.75,
) -> List[float]:
    """Build an 8-dim observation vector matching the training environment."""
    # Features: [capacity_ratio, availability, disruption, disruption_flag,
    #            disruption_severity, utilization, neighbor_avg, time_feature]
    return [
        float(np.clip(capacity_ratio, 0, 1)),
        float(np.clip(availability, 0, 1)),
        float(np.clip(disruption, 0, 1)),
        1.0 if disruption > 0.5 else 0.0,
        float(np.clip(disruption, 0, 1)),
        float(np.clip(utilization, 0, 1)),
        float(np.clip(neighbor_avg, 0, 1)),
        0.5,  # normalised time feature (mid-day default)
    ]


def _num(record: dict, key: str, default: float) -> float:
    """Read ``record[key]`` as a float, using ``default`` only when it is missing.

    Missing means absent, ``None`` or a blank string. Unlike
    ``float(record.get(key) or default)``, a genuine 0 (e.g. an empty fuel tank
    or 0% on-time delivery) is kept instead of being replaced by the default.
    """
    value = record.get(key)
    if value is None or (isinstance(value, str) and not value.strip()):
        return float(default)
    return float(value)


def _status(record: dict, default: str) -> str:
    """Return the record's ``status`` lower-cased and stripped, or ``default`` if NULL/absent."""
    raw = record.get("status")
    if raw is None:
        return default
    return str(raw).strip().lower()


def build_obs_from_resources(trucks, warehouses, carriers) -> List[List[float]]:
    """Map real DB resources to 15-agent observation vectors.

    Columns come from the actual Neon DB tables (SELECT *):
      trucks:     vehicle_number, capacity_kg, fuel_level, status, model, ...
      warehouses: name, capacity_sqft, utilization_percent, status, ...
      carriers:   name, rating, on_time_delivery_percent, status, ...

    At most ``_MAX_PER_KIND`` records of each kind are used, in the order
    trucks, warehouses, carriers; the result is padded with neutral
    observations to exactly ``N_AGENTS`` rows.

    Raises:
        ValueError / TypeError: a numeric column holds a non-numeric value.
    """
    obs = []

    for t in trucks[:_MAX_PER_KIND]:
        # capacity_kg up to ~30000 kg; normalise to [0,1]. A 0/empty capacity
        # is treated as unset and falls back to `capacity`, then 10000 kg.
        cap_kg = float(t.get("capacity_kg") or t.get("capacity") or 10000)
        cap_ratio = min(cap_kg / 30000.0, 1.0)
        fuel = _num(t, "fuel_level", 100) / 100.0
        status = _status(t, "available")
        availability = _TRUCK_AVAILABILITY.get(status, 0.5)
        # low fuel or maintenance → disruption signal
        disruption = max(0.0, 1.0 - fuel) * 0.5 + (0.5 if status == "maintenance" else 0.0)
        obs.append(
            _resource_obs(cap_ratio, availability, min(disruption, 1.0), 1.0 - availability)
        )

    for w in warehouses[:_MAX_PER_KIND]:
        utilization = _num(w, "utilization_percent", 0) / 100.0
        status = _status(w, "operational")
        availability = _WAREHOUSE_AVAILABILITY.get(status, 0.7)
        disruption = max(0.0, utilization - 0.80)  # signal only when >80% full
        obs.append(_resource_obs(1.0 - utilization, availability, disruption, utilization))

    for c in carriers[:_MAX_PER_KIND]:
        # A 0/empty rating is treated as "unrated" and defaults to 4.0.
        rating = float(c.get("rating") or 4.0)
        rating_norm = min(rating / 5.0, 1.0)
        on_time = _num(c, "on_time_delivery_percent", 85) / 100.0
        status = _status(c, "active")
        availability = 1.0 if status == "active" else 0.3
        disruption = 1.0 - on_time
        obs.append(_resource_obs(rating_norm, availability, disruption, 1.0 - on_time))

    # Pad to exactly 15 agents with neutral observations
    while len(obs) < N_AGENTS:
        obs.append(_resource_obs(0.8, 0.8, 0.05, 0.4))
    return obs[:N_AGENTS]


def _infer_actions(obs_n) -> List[int]:
    """Run one greedy forward pass and return the argmax action index per agent.

    ``obs_n`` must be array-like with shape ``[N_AGENTS, OBS_DIM]``. Each agent
    is fed as a length-1 sequence with a zero initial hidden state, so no
    recurrent state carries over between requests.
    """
    obs = np.asarray(obs_n, dtype=np.float32)
    obs_tensor = torch.as_tensor(obs, device=device).unsqueeze(1)  # [15, 1, 8]
    h_state = torch.zeros(1, N_AGENTS, HIDDEN_DIM, device=device)
    with torch.no_grad():
        q_vals, _ = agent(obs_tensor, h_state)
    return q_vals[:, 0, :].argmax(dim=-1).tolist()


@app.post("/predict")
def predict(data: InferenceRequest):
    """Return greedy per-agent actions for caller-supplied observations.

    Raises:
        HTTPException(422): ``obs_n`` is ragged or not ``[N_AGENTS, OBS_DIM]``.
        HTTPException(500): model inference failed.
    """
    # Validate before the inference boundary so malformed input is reported as
    # a client error instead of being wrapped into a 500.
    try:
        obs = np.array(data.obs_n, dtype=np.float32)
    except ValueError as e:  # ragged rows
        raise HTTPException(
            status_code=422, detail=f"obs_n must be [{N_AGENTS}, {OBS_DIM}]: {e}"
        ) from e
    if obs.shape != (N_AGENTS, OBS_DIM):
        raise HTTPException(
            status_code=422,
            detail=f"obs_n must be [{N_AGENTS}, {OBS_DIM}], got {obs.shape}",
        )

    try:
        actions = _infer_actions(obs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "actions": actions,
        "action_labels": [ACTION_LABELS[a] for a in actions],
    }


@app.post("/allocate")
def allocate(data: AllocationRequest):
    """High-level endpoint: accepts real resource dicts, returns labelled recommendations.

    Raises:
        HTTPException(422): a resource has a non-numeric value in a numeric column.
        HTTPException(500): model inference or response assembly failed.
    """
    try:
        obs_n = build_obs_from_resources(data.trucks, data.warehouses, data.carriers)
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    try:
        actions = _infer_actions(obs_n)

        # (kind, record) pairs in the same order used for observations. Kept as
        # tuples rather than {"type": kind, **record} so a record's own "type"
        # column can neither overwrite the kind nor masquerade as "virtual".
        resources = (
            [("truck", t) for t in data.trucks[:_MAX_PER_KIND]]
            + [("warehouse", w) for w in data.warehouses[:_MAX_PER_KIND]]
            + [("carrier", c) for c in data.carriers[:_MAX_PER_KIND]]
        )

        recommendations = []
        # zip stops at the last real resource, so padding agents are skipped.
        for i, ((kind, res), act) in enumerate(zip(resources, actions)):
            name = (
                res.get("name")
                or res.get("vehicle_number")
                or res.get("driver_name")
                or res.get("driver")
                or f"Resource {i}"
            )
            recommendations.append({
                "agent_id": i,
                "resource_type": kind,
                "resource_id": str(res.get("id", i)),
                "resource_name": name,
                "action": ACTION_LABELS[act],
                "action_code": act,
                "description": ACTION_DESCRIPTIONS[act],
                "priority": _ACTION_PRIORITY.get(act, "low"),
                "obs": obs_n[i],
            })

        # Average over real resources only; the neutral padding agents would
        # otherwise pull the estimate toward 0.8 * (1 - 0.05) = 0.76.
        real_obs = obs_n[: len(resources)]
        delivery_estimate = (
            float(np.mean([o[0] * (1.0 - o[2]) for o in real_obs])) if real_obs else 0.0
        )

        return {
            "recommendations": recommendations,
            "summary": {
                "total_agents": len(recommendations),
                "actions_needed": sum(1 for r in recommendations if r["action"] != "hold"),
                "high_priority": sum(1 for r in recommendations if r["priority"] == "high"),
                "delivery_rate_estimate": round(delivery_estimate, 3),
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
def health():
    """Liveness probe plus basic model metadata.

    ``weights_loaded`` reflects whether weights were loaded at startup.
    """
    return {
        "status": "active",
        "model": "QMIX Multi-Agent RL Resource Allocator",
        "n_agents": N_AGENTS,
        "obs_dim": OBS_DIM,
        "n_actions": N_ACTIONS,
        "weights_loaded": WEIGHTS_LOADED,
    }