| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import os |
| from typing import List, Optional |
|
|
app = FastAPI(title="Indian Supply Chain MARL API")

# Fully open CORS policy: every origin, HTTP method and request header is accepted.
# allow_credentials is not passed, so the middleware default applies.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
# Environment sizes. OBS_DIM, N_ACTIONS and HIDDEN_DIM size the QAgent layers, so they
# must agree with the saved weights file; N_AGENTS is the number of agent slots
# (up to 5 trucks + 5 warehouses + 5 carriers, padded with virtual agents).
N_AGENTS, OBS_DIM, N_ACTIONS, HIDDEN_DIM = 15, 8, 5, 64
# Global state size: 6 features per agent (not referenced by the endpoints in this file).
STATE_DIM = 6 * N_AGENTS

# (label, human-readable description) per action index; the index is the Q-network output.
_ACTION_TABLE = (
    ("hold", "Maintain current allocation — no change needed"),
    ("optimize", "Minor load rebalancing — reduce throughput by 10%"),
    ("redirect", "Reroute shipments — major load reduction by 20%"),
    ("scale_back", "Scale back capacity — reduce utilization by 15%"),
    ("emergency", "Emergency response — redistribute to neighboring nodes immediately"),
)
ACTION_LABELS = [label for label, _ in _ACTION_TABLE]
ACTION_DESCRIPTIONS = [description for _, description in _ACTION_TABLE]
|
|
|
|
class QAgent(nn.Module):
    """Recurrent per-agent Q-network: a GRU encoder followed by a two-layer MLP head.

    The submodule attribute names (``gru``, ``fc1``, ``fc2``) become the state_dict
    keys, so they must match the keys in any weights file passed to load_state_dict.
    """

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        head_width = 32
        self.gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, head_width)
        self.fc2 = nn.Linear(head_width, n_actions)

    def forward(self, x, hidden):
        """Return ``(q_values, next_hidden)`` for input ``x`` and GRU state ``hidden``.

        ``x`` is batch-first (the GRU is built with ``batch_first=True``); Q-values are
        produced for every time step of the GRU output.
        """
        features, next_hidden = self.gru(x, hidden)
        q_values = self.fc2(self.fc1(features).relu())
        return q_values, next_hidden
|
|
|
|
device = torch.device("cpu")
agent = QAgent(OBS_DIM, N_ACTIONS, HIDDEN_DIM).to(device)


def _resolve_weights_path(filename: str = "qmix_agent_weights.pth") -> str:
    """Locate the trained weights file.

    The working-directory-relative path (the original lookup) is tried first, then the
    same filename next to this module, so launching the server from another directory
    still finds weights shipped alongside the code. If neither exists, *filename* is
    returned unchanged.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    for candidate in (filename, os.path.join(module_dir, filename)):
        if os.path.exists(candidate):
            return candidate
    return filename


def _load_agent_weights(model: nn.Module, path: str) -> None:
    """Load a state_dict from *path* into *model*, warning loudly if the file is missing.

    Without weights the network keeps its random initialisation and every endpoint
    would silently return untrained actions, so the fallback is logged.
    """
    if not os.path.exists(path):
        import logging

        logging.getLogger(__name__).warning(
            "Weights file %r not found; serving an untrained QAgent with random weights.",
            path,
        )
        return
    model.load_state_dict(torch.load(path, map_location=device))


weights_path = _resolve_weights_path()
_load_agent_weights(agent, weights_path)
agent.eval()
|
|
|
|
class InferenceRequest(BaseModel):
    # Request body for POST /predict.
    # obs_n: one observation row per agent; /predict requires exactly N_AGENTS rows of
    # OBS_DIM floats (the shape is checked inside the endpoint, not by this model).
    obs_n: List[List[float]]
    # NOTE(review): accepted by the schema but never read by /predict, which always starts
    # from a zero GRU hidden state — confirm whether clients expect it to be honoured.
    internal_state: Optional[List[List[float]]] = None
|
|
|
|
class AllocationRequest(BaseModel):
    # Request body for POST /allocate: raw resource records, one dict per row.
    # Only the first five entries of each list are mapped to agents.
    # Keys read downstream include capacity_kg / capacity / fuel_level / status (trucks),
    # utilization_percent / status (warehouses), rating / on_time_delivery_percent / status
    # (carriers), plus id / name / vehicle_number / driver_name / driver for labelling.
    trucks: List[dict]
    warehouses: List[dict]
    carriers: List[dict]
|
|
|
|
def _resource_obs(
    capacity_ratio: float,
    availability: float,
    disruption: float,
    utilization: float,
    neighbor_avg: float = 0.75,
) -> List[float]:
    """Assemble the 8-element per-agent observation used as QAgent input.

    Layout: [capacity, availability, disruption, disrupted-flag, disruption,
    utilization, neighbor average, 0.5]. Every continuous entry is clipped to [0, 1];
    the flag is 1.0 when the *unclipped* disruption exceeds 0.5. Index 4 repeats the
    disruption value and index 7 is a constant 0.5 — presumably placeholders for
    features of the training environment (TODO confirm against the trainer).
    """

    def unit(value: float) -> float:
        return float(np.clip(value, 0, 1))

    disruption_unit = unit(disruption)
    disrupted_flag = 1.0 if disruption > 0.5 else 0.0
    return [
        unit(capacity_ratio),
        unit(availability),
        disruption_unit,
        disrupted_flag,
        disruption_unit,
        unit(utilization),
        unit(neighbor_avg),
        0.5,
    ]
|
|
|
|
# Availability score per (normalised) status string; unknown statuses use a per-kind default.
_TRUCK_AVAILABILITY = {"available": 1.0, "in_transit": 0.7, "maintenance": 0.2}
_WAREHOUSE_AVAILABILITY = {"operational": 1.0, "active": 1.0, "maintenance": 0.3, "inactive": 0.1}


def _numeric_field(record: dict, *keys: str, default: float) -> float:
    """Return the first usable number among ``record[key]`` for *keys*, else *default*.

    A value is usable when it is present, not None, accepted by ``float()`` and finite.
    Unlike ``record.get(key) or default`` this keeps a genuine ``0`` (an empty fuel tank
    must not read as a full one) while still treating "", "n/a", NaN, etc. as missing.
    """
    for key in keys:
        value = record.get(key)
        if value is None:
            continue
        try:
            number = float(value)
        except (TypeError, ValueError):
            continue  # unparseable text: treat like a missing field
        if np.isfinite(number):
            return number
    return default


def _status(record: dict, default: str) -> str:
    """Return the record's status, stripped and lower-cased, or *default* if missing/None."""
    value = record.get("status")
    if value is None:
        return default
    return str(value).strip().lower()


def _truck_obs(truck: dict) -> List[float]:
    """Observation vector for one truck record."""
    cap_kg = _numeric_field(truck, "capacity_kg", "capacity", default=10000.0)
    cap_ratio = min(cap_kg / 30000.0, 1.0)  # 30,000 kg or more counts as full capacity
    fuel = _numeric_field(truck, "fuel_level", default=100.0) / 100.0  # percent -> fraction
    status = _status(truck, "available")
    availability = _TRUCK_AVAILABILITY.get(status, 0.5)
    # Missing fuel contributes up to 0.5 disruption; being in maintenance adds another 0.5.
    disruption = max(0.0, 1.0 - fuel) * 0.5 + (0.5 if status == "maintenance" else 0.0)
    return _resource_obs(cap_ratio, availability, min(disruption, 1.0), 1.0 - availability)


def _warehouse_obs(warehouse: dict) -> List[float]:
    """Observation vector for one warehouse record."""
    utilization = _numeric_field(warehouse, "utilization_percent", default=0.0) / 100.0
    status = _status(warehouse, "operational")
    availability = _WAREHOUSE_AVAILABILITY.get(status, 0.7)
    disruption = max(0.0, utilization - 0.80)  # only utilisation above 80% counts
    return _resource_obs(1.0 - utilization, availability, disruption, utilization)


def _carrier_obs(carrier: dict) -> List[float]:
    """Observation vector for one carrier record."""
    rating_norm = min(_numeric_field(carrier, "rating", default=4.0) / 5.0, 1.0)  # 5-point scale
    on_time = _numeric_field(carrier, "on_time_delivery_percent", default=85.0) / 100.0
    availability = 1.0 if _status(carrier, "active") == "active" else 0.3
    disruption = 1.0 - on_time
    return _resource_obs(rating_norm, availability, disruption, 1.0 - on_time)


def build_obs_from_resources(trucks, warehouses, carriers) -> List[List[float]]:
    """Map real DB resources to 15-agent observation vectors.

    Columns come from the actual Neon DB tables (SELECT *):
        trucks:     vehicle_number, capacity_kg, fuel_level, status, model, ...
        warehouses: name, capacity_sqft, utilization_percent, status, ...
        carriers:   name, rating, on_time_delivery_percent, status, ...

    Agent slots are filled in order: up to 5 trucks, then up to 5 warehouses, then up
    to 5 carriers; remaining slots are padded with a fixed "virtual" observation.
    Missing, NULL or unparseable numeric fields fall back to per-field defaults, but a
    real 0 is kept. Status strings are compared case-insensitively.

    Returns:
        Exactly N_AGENTS observation lists.
    """
    obs = [_truck_obs(t) for t in trucks[:5]]
    obs += [_warehouse_obs(w) for w in warehouses[:5]]
    obs += [_carrier_obs(c) for c in carriers[:5]]

    # Pad with fresh (not shared) virtual-agent vectors up to N_AGENTS.
    obs.extend(_resource_obs(0.8, 0.8, 0.05, 0.4) for _ in range(N_AGENTS - len(obs)))
    return obs[:N_AGENTS]
|
|
|
|
@app.post("/predict")
def predict(data: InferenceRequest):
    """Run the QAgent network greedily on a caller-supplied observation matrix.

    ``data.obs_n`` must be exactly ``[N_AGENTS, OBS_DIM]`` finite floats. Each agent
    is evaluated for a single step from a zero GRU hidden state.

    Returns:
        ``{"actions": [int, ...], "action_labels": [str, ...]}`` with one entry per agent.

    Raises:
        HTTPException: 422 if ``obs_n`` is ragged, has the wrong shape, or contains
            non-finite values (client error); 500 if inference itself fails.
    """
    # NOTE(review): data.internal_state is accepted by the request model but not used here.
    try:
        obs = np.array(data.obs_n, dtype=np.float32)
    except ValueError as e:  # ragged rows cannot form a 2-D array
        raise HTTPException(
            status_code=422, detail=f"obs_n must be a rectangular numeric array: {e}"
        ) from e
    if obs.shape != (N_AGENTS, OBS_DIM):
        raise HTTPException(
            status_code=422,
            detail=f"obs_n must be [{N_AGENTS}, {OBS_DIM}], got {obs.shape}",
        )
    if not np.isfinite(obs).all():
        raise HTTPException(status_code=422, detail="obs_n must contain only finite values")

    try:
        obs_tensor = torch.from_numpy(obs).unsqueeze(1)  # (N_AGENTS, 1 step, OBS_DIM)
        h_state = torch.zeros(1, N_AGENTS, HIDDEN_DIM)
        with torch.no_grad():
            q_vals, _ = agent(obs_tensor, h_state)
        # Greedy action per agent; torch.argmax returns the first index on ties.
        actions = q_vals[:, 0].argmax(dim=-1).tolist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "actions": actions,
        "action_labels": [ACTION_LABELS[a] for a in actions],
    }
|
|
|
|
@app.post("/allocate")
def allocate(data: AllocationRequest):
    """High-level endpoint: accepts real resource dicts, returns labelled recommendations.

    Up to five trucks, five warehouses and five carriers are mapped to agent slots in
    that order via ``build_obs_from_resources``. Remaining slots are virtual padding
    agents; they get no recommendation and are excluded from the summary.

    Raises:
        HTTPException: status 500 carrying the error message if anything fails.
    """
    try:
        obs_n = build_obs_from_resources(data.trucks, data.warehouses, data.carriers)
        obs_tensor = torch.from_numpy(np.array(obs_n, dtype=np.float32)).unsqueeze(1)
        h_state = torch.zeros(1, N_AGENTS, HIDDEN_DIM)
        with torch.no_grad():
            q_vals, _ = agent(obs_tensor, h_state)
        # Greedy action per agent; torch.argmax returns the first index on ties.
        actions = q_vals[:, 0].argmax(dim=-1).tolist()

        # (resource_type, record) for each real agent slot, in the same order used to
        # build obs_n. The type is kept outside the record so that a DB column named
        # "type" cannot overwrite it (or make a real resource look "virtual").
        slots = (
            [("truck", t) for t in data.trucks[:5]]
            + [("warehouse", w) for w in data.warehouses[:5]]
            + [("carrier", c) for c in data.carriers[:5]]
        )[:N_AGENTS]

        recommendations = []
        for i, ((res_type, res), act) in enumerate(zip(slots, actions)):
            name = (
                res.get("name")
                or res.get("vehicle_number")
                or res.get("driver_name")
                or res.get("driver")
                or f"Resource {i}"
            )
            res_id = res.get("id")
            recommendations.append({
                "agent_id": i,
                "resource_type": res_type,
                "resource_id": str(i if res_id is None else res_id),
                "resource_name": name,
                "action": ACTION_LABELS[act],
                "action_code": act,
                "description": ACTION_DESCRIPTIONS[act],
                "priority": "high" if act in (3, 4) else "medium" if act in (1, 2) else "low",
                "obs": obs_n[i],
            })

        # Average over real resources only: the virtual padding slots carry a fixed
        # placeholder observation and would otherwise skew the estimate.
        delivery_scores = [obs_n[i][0] * (1.0 - obs_n[i][2]) for i in range(len(slots))]
        delivery_estimate = float(np.mean(delivery_scores)) if delivery_scores else 0.0

        return {
            "recommendations": recommendations,
            "summary": {
                "total_agents": len(recommendations),
                "actions_needed": sum(1 for r in recommendations if r["action"] != "hold"),
                "high_priority": sum(1 for r in recommendations if r["priority"] == "high"),
                "delivery_rate_estimate": round(delivery_estimate, 3),
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.get("/")
def health():
    """Liveness endpoint describing the served model configuration.

    ``weights_loaded`` reports whether ``weights_path`` exists at request time; it is
    not a record of whether weights were actually loaded when the module started.
    """
    return dict(
        status="active",
        model="QMIX Multi-Agent RL Resource Allocator",
        n_agents=N_AGENTS,
        obs_dim=OBS_DIM,
        n_actions=N_ACTIONS,
        weights_loaded=os.path.exists(weights_path),
    )
|
|