virustechhacks's picture
Upload app.py with huggingface_hub
3b9f511 verified
import logging
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ASGI application for the supply-chain MARL inference API.
app = FastAPI(title="Indian Supply Chain MARL API")

# CORS is wide open: any origin, any method, any header may call the API.
# allow_credentials is not passed, so it stays at the middleware's default.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Shapes of the multi-agent setup served by this API.
N_AGENTS = 15     # agent slots; /allocate fills up to 5 trucks, 5 warehouses, 5 carriers
OBS_DIM = 8       # length of each per-agent observation row
HIDDEN_DIM = 64   # GRU hidden size of QAgent
STATE_FEAT_DIM = 6  # per-agent global-state features (state_feat_dim)
STATE_DIM = N_AGENTS * STATE_FEAT_DIM

# Discrete action space; index i of each list describes action code i.
ACTION_LABELS = ["hold", "optimize", "redirect", "scale_back", "emergency"]
N_ACTIONS = len(ACTION_LABELS)
ACTION_DESCRIPTIONS = [
    "Maintain current allocation — no change needed",
    "Minor load rebalancing — reduce throughput by 10%",
    "Reroute shipments — major load reduction by 20%",
    "Scale back capacity — reduce utilization by 15%",
    "Emergency response — redistribute to neighboring nodes immediately",
]
class QAgent(nn.Module):
    """Recurrent per-agent Q-network: a GRU encoder followed by a two-layer MLP head.

    The submodule attribute names (``gru``, ``fc1``, ``fc2``) are the
    ``state_dict`` keys that saved weights are loaded into, so they must not
    be renamed.
    """

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(input_size=obs_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=n_actions)

    def forward(self, x, hidden):
        """Return ``(q_values, next_hidden)`` for a batch-first input sequence.

        ``x`` is laid out batch-first because the GRU is built with
        ``batch_first=True``; ``hidden`` is the GRU's initial hidden state
        (callers in this file pass zeros of shape (1, n_agents, hidden_dim)).
        The MLP head is applied to every time step of the GRU output.
        """
        encoded, next_hidden = self.gru(x, hidden)
        q_values = self.fc2(self.fc1(encoded).relu())
        return q_values, next_hidden
# Single shared agent network, run on CPU for every inference request.
device = torch.device("cpu")
agent = QAgent(OBS_DIM, N_ACTIONS, HIDDEN_DIM).to(device)
# Resolved relative to the process working directory.
weights_path = "qmix_agent_weights.pth"
if os.path.exists(weights_path):
    # NOTE(review): torch.load unpickles the file. It is a local artifact
    # shipped with the app rather than user input; consider passing
    # weights_only=True once the minimum supported torch version allows it.
    agent.load_state_dict(torch.load(weights_path, map_location=device))
else:
    # Without this, the API silently serves a randomly initialised policy.
    logging.getLogger(__name__).warning(
        "Weights file %r not found; serving a randomly initialised QAgent.",
        weights_path,
    )
# Inference-only module: put it in eval mode once at startup.
agent.eval()
class InferenceRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Per-agent observation matrix; /predict requires exactly [N_AGENTS, OBS_DIM].
    obs_n: List[List[float]] # [n_agents, obs_dim]
    # Optional per-agent internal state. Accepted by the schema, but the
    # /predict handler in this file never reads it (it always starts from a
    # zero GRU hidden state).
    internal_state: Optional[List[List[float]]] = None # [n_agents, 6]
class AllocationRequest(BaseModel):
    """Request body for ``POST /allocate``: raw resource rows, one dict per row.

    Only the first five rows of each list are used. The keys read from each
    row are those consumed by ``build_obs_from_resources`` (for example
    ``capacity_kg``, ``fuel_level`` and ``status`` for trucks), plus
    ``id`` / ``name`` / ``vehicle_number`` when labelling recommendations.
    """

    trucks: List[dict]
    warehouses: List[dict]
    carriers: List[dict]
def _resource_obs(
    capacity_ratio: float,
    availability: float,
    disruption: float,
    utilization: float,
    neighbor_avg: float = 0.75,
) -> List[float]:
    """Assemble one agent's 8-element observation row in training-feature order.

    Layout: [capacity_ratio, availability, disruption, disruption_flag,
    disruption_severity, utilization, neighbor_avg, time_feature].

    Every continuous input is clamped to [0, 1]. ``disruption`` fills both the
    raw slot and the severity slot. The flag is 1.0 when the *unclamped*
    disruption exceeds 0.5. The time feature is fixed at 0.5.
    """
    def unit(value: float) -> float:
        # Clamp into [0, 1] and hand back a plain Python float.
        return float(np.clip(value, 0, 1))

    severity = unit(disruption)
    flag = 1.0 if disruption > 0.5 else 0.0
    row = [unit(capacity_ratio), unit(availability), severity, flag, severity]
    row += [unit(utilization), unit(neighbor_avg), 0.5]  # 0.5: normalised time, mid-day default
    return row
def _numeric_field(record: dict, *keys: str, default: float) -> float:
    """Return the first usable numeric value among ``keys`` in ``record``.

    A key is skipped when it is absent or holds ``None`` / ``""``. ``default``
    is used only if every key is skipped. A genuine ``0`` is returned as 0.0.
    The earlier ``float(record.get(key) or default)`` idiom treated 0 as
    missing, so an empty fuel tank read as full and a 0% on-time rate as 85%.
    Non-numeric strings still raise ``ValueError`` from ``float()``.
    """
    for key in keys:
        value = record.get(key)
        if value is not None and value != "":
            return float(value)
    return float(default)


def _normalized_status(record: dict, default: str) -> str:
    """Return ``record["status"]`` stripped and lower-cased, or ``default`` if absent/None."""
    value = record.get("status")
    if value is None:
        return default
    return str(value).strip().lower()


def build_obs_from_resources(trucks, warehouses, carriers) -> List[List[float]]:
    """Map real DB resources to 15-agent observation vectors.

    Columns come from the actual Neon DB tables (SELECT *):
      trucks: vehicle_number, capacity_kg, fuel_level, status, model, ...
      warehouses: name, capacity_sqft, utilization_percent, status, ...
      carriers: name, rating, on_time_delivery_percent, status, ...

    Slots are filled in order: up to 5 trucks, then up to 5 warehouses, then
    up to 5 carriers. Any remaining slots up to N_AGENTS get a fixed neutral
    observation. Numeric fields fall back to a default only when missing
    (absent, None or ""), so real zeros are honoured. Status strings are
    compared case-insensitively.

    Returns:
        Exactly N_AGENTS rows, each built by ``_resource_obs``.
    """
    obs: List[List[float]] = []

    for t in trucks[:5]:
        # capacity_kg up to ~30000 kg; normalise to [0, 1]
        cap_kg = _numeric_field(t, "capacity_kg", "capacity", default=10000.0)
        cap_ratio = min(cap_kg / 30000.0, 1.0)
        fuel = _numeric_field(t, "fuel_level", default=100.0) / 100.0
        status = _normalized_status(t, "available")
        availability = {"available": 1.0, "in_transit": 0.7, "maintenance": 0.2}.get(status, 0.5)
        # low fuel or maintenance -> disruption signal
        disruption = max(0.0, 1.0 - fuel) * 0.5 + (0.5 if status == "maintenance" else 0.0)
        obs.append(_resource_obs(cap_ratio, availability, min(disruption, 1.0), 1.0 - availability))

    for w in warehouses[:5]:
        utilization = _numeric_field(w, "utilization_percent", default=0.0) / 100.0
        status = _normalized_status(w, "operational")
        availability = {"operational": 1.0, "active": 1.0, "maintenance": 0.3, "inactive": 0.1}.get(status, 0.7)
        disruption = max(0.0, utilization - 0.80)  # signal only when >80% full
        obs.append(_resource_obs(1.0 - utilization, availability, disruption, utilization))

    for c in carriers[:5]:
        rating = _numeric_field(c, "rating", default=4.0)
        rating_norm = min(rating / 5.0, 1.0)
        on_time = _numeric_field(c, "on_time_delivery_percent", default=85.0) / 100.0
        status = _normalized_status(c, "active")
        availability = 1.0 if status == "active" else 0.3
        disruption = 1.0 - on_time
        obs.append(_resource_obs(rating_norm, availability, disruption, 1.0 - on_time))

    # Pad to exactly N_AGENTS agents with neutral observations
    while len(obs) < N_AGENTS:
        obs.append(_resource_obs(0.8, 0.8, 0.05, 0.4))
    return obs[:N_AGENTS]
@app.post("/predict")
def predict(data: InferenceRequest):
    """Run one greedy forward pass of the shared agent network on raw observations.

    ``data.obs_n`` must be an ``[N_AGENTS, OBS_DIM]`` matrix. The request's
    ``internal_state`` is not used here: every call starts from a zero GRU
    hidden state.

    Returns:
        ``{"actions": [...], "action_labels": [...]}`` with one entry per agent.

    Raises:
        HTTPException(422): ``obs_n`` is ragged or has the wrong shape. This is
            a client error, previously misreported as 500.
        HTTPException(500): any unexpected failure during inference.
    """
    try:
        obs = np.asarray(data.obs_n, dtype=np.float32)
    except (TypeError, ValueError) as e:  # e.g. rows of different lengths
        raise HTTPException(
            status_code=422,
            detail=f"obs_n must be a rectangular numeric matrix: {e}",
        ) from e
    if obs.shape != (N_AGENTS, OBS_DIM):
        raise HTTPException(
            status_code=422,
            detail=f"obs_n must be [{N_AGENTS}, {OBS_DIM}], got {obs.shape}",
        )

    try:
        obs_tensor = torch.from_numpy(obs).unsqueeze(1)  # [N_AGENTS, 1, OBS_DIM]: one time step each
        h_state = torch.zeros(1, N_AGENTS, HIDDEN_DIM)
        with torch.no_grad():
            q_vals, _ = agent(obs_tensor, h_state)
        # Greedy action per agent from the single time step.
        actions = q_vals[:, 0].argmax(dim=-1).tolist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "actions": actions,
        "action_labels": [ACTION_LABELS[a] for a in actions],
    }
@app.post("/allocate")
def allocate(data: AllocationRequest):
    """High-level endpoint: accepts real resource dicts, returns labelled recommendations.

    Up to five trucks, five warehouses and five carriers are mapped to agent
    slots, in that order, by ``build_obs_from_resources``. Remaining slots are
    padded with neutral observations. One greedy action is computed per slot
    from a zero GRU hidden state. A recommendation is emitted for every real
    resource; padding slots are not reported.

    Returns:
        ``{"recommendations": [...], "summary": {...}}``. The summary,
        including ``delivery_rate_estimate``, covers real resources only.
        ``delivery_rate_estimate`` is ``0.0`` when no resources were supplied.

    Raises:
        HTTPException(500): any error while building observations or running
            inference (message in ``detail``).
    """
    try:
        obs_n = build_obs_from_resources(data.trucks, data.warehouses, data.carriers)
        obs_tensor = torch.from_numpy(np.asarray(obs_n, dtype=np.float32)).unsqueeze(1)
        h_state = torch.zeros(1, N_AGENTS, HIDDEN_DIM)
        with torch.no_grad():
            q_vals, _ = agent(obs_tensor, h_state)
        actions = q_vals[:, 0].argmax(dim=-1).tolist()

        # (kind, row) pairs in exactly the order build_obs_from_resources
        # consumes them, so pair i lines up with observation row i. The kind is
        # kept outside the row dict. Spreading the row after a "type" key (as in
        # {"type": "truck", **row}) let a DB column named "type" overwrite it.
        resources = (
            [("truck", t) for t in data.trucks[:5]]
            + [("warehouse", w) for w in data.warehouses[:5]]
            + [("carrier", c) for c in data.carriers[:5]]
        )[:N_AGENTS]

        recommendations = []
        # zip stops at the last real resource, so padding slots are skipped.
        for i, ((kind, res), act) in enumerate(zip(resources, actions)):
            name = (
                res.get("name")
                or res.get("vehicle_number")
                or res.get("driver_name")
                or res.get("driver")
                or f"Resource {i}"
            )
            recommendations.append({
                "agent_id": i,
                "resource_type": kind,
                "resource_id": str(res.get("id", i)),
                "resource_name": name,
                "action": ACTION_LABELS[act],
                "action_code": act,
                "description": ACTION_DESCRIPTIONS[act],
                "priority": "high" if act in (3, 4) else "medium" if act in (1, 2) else "low",
                "obs": obs_n[i],
            })

        # Padding rows are synthetic neutral observations. Averaging them in
        # pulled the estimate toward a constant whenever few resources existed.
        real_obs = obs_n[: len(resources)]
        delivery_estimate = (
            float(np.mean([o[0] * (1.0 - o[2]) for o in real_obs])) if real_obs else 0.0
        )
        return {
            "recommendations": recommendations,
            "summary": {
                "total_agents": len(recommendations),
                "actions_needed": sum(1 for r in recommendations if r["action"] != "hold"),
                "high_priority": sum(1 for r in recommendations if r["priority"] == "high"),
                "delivery_rate_estimate": round(delivery_estimate, 3),
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def health():
    """Liveness probe that also reports the served model's configuration.

    ``weights_loaded`` is re-evaluated on every request as "the weights file
    exists on disk now". It does not record whether weights were actually
    loaded into the model at startup.
    """
    return dict(
        status="active",
        model="QMIX Multi-Agent RL Resource Allocator",
        n_agents=N_AGENTS,
        obs_dim=OBS_DIM,
        n_actions=N_ACTIONS,
        weights_loaded=os.path.exists(weights_path),
    )