Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from scipy.stats import gaussian_kde
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# 1. Hamiltonian Optimal Control Function
|
| 9 |
+
def hamilton_optimal_control(x, alpha_k, z, local_weight, pos_idx, d_sq):
    """Pick the local target centroid an agent should steer toward.

    Candidate sample points are ranked by ``d_sq / weight**2`` (smallest
    first) and accumulated until ``alpha_k`` units of weight are gathered.
    The weighted mean of the selected points is returned together with the
    total weight actually gathered.

    Args:
        x: Agent position. Not read here (distances arrive precomputed in
            ``d_sq``); kept so the call signature stays stable.
        alpha_k: Amount of weight to gather in this step.
        z: 2-D array of sample point coordinates, one point per row.
        local_weight: Remaining weight per sample point, aligned with ``z``.
        pos_idx: Integer index array of the candidate points.
        d_sq: Squared distance from the agent to each candidate, aligned
            with ``pos_idx``.

    Returns:
        ``(xg, gamma)`` where ``xg`` is the weighted centroid of the selected
        points and ``gamma`` is their total weight, or ``(None, 0)`` when
        ``pos_idx`` is empty.
    """
    if len(pos_idx) == 0:
        return None, 0

    # Rank candidates by squared distance scaled by squared weight; the
    # epsilon keeps the division finite for vanishing weights.
    cost = d_sq / (local_weight[pos_idx] ** 2 + 1e-12)
    ranked_idx = pos_idx[np.argsort(cost)]
    ranked_w = local_weight[ranked_idx]
    cumsum_w = np.cumsum(ranked_w)

    # First rank at which the running total reaches alpha_k; clamp to the last
    # candidate when all candidates together hold less than alpha_k.
    cutoff = min(int(np.searchsorted(cumsum_w, alpha_k)), len(ranked_idx) - 1)

    final_idx = ranked_idx[:cutoff + 1]
    final_w = ranked_w[:cutoff + 1].copy()

    # The last selected point contributes only what is still needed to reach
    # alpha_k, but never more than the weight it actually holds. Without the
    # cap, the clamped case above would inflate that point's weight and drag
    # the centroid toward it.
    still_needed = alpha_k - (cumsum_w[cutoff - 1] if cutoff > 0 else 0.0)
    final_w[-1] = min(final_w[-1], still_needed)

    y_loc_pos = z[final_idx]
    y_loc_wgt = final_w.reshape(-1, 1)
    gamma = np.sum(y_loc_wgt)
    # Epsilon guards the division when the gathered mass is (numerically) zero.
    xg = np.sum(y_loc_wgt * y_loc_pos, axis=0) / (gamma + 1e-12)
    return xg, gamma
|
| 28 |
+
|
| 29 |
+
# 2. Local Weight Update Function
|
| 30 |
+
def update_weight_local(x, z, local_weight, alpha_k):
    """Consume up to ``alpha_k`` of weight from the sample points nearest ``x``.

    Points that still carry weight are visited from nearest to farthest.
    Each is drained completely until the running total would reach
    ``alpha_k``; the point at which that happens gives up only the remainder.
    If all remaining weight is below ``alpha_k``, every point is drained.

    ``local_weight`` is modified in place and is also returned.

    Args:
        x: Agent position, broadcastable against the rows of ``z``.
        z: 2-D array of sample point coordinates, one point per row.
        local_weight: Remaining weight per sample point, aligned with ``z``.
        alpha_k: Amount of weight to consume in this step.

    Returns:
        Tuple ``(local_weight, weight_dist, new_pos_idx, new_dist_sq)``:
        the updated weights, the weight taken from each point in this call,
        the indices that still carry weight afterwards, and the squared
        distances from ``x`` to those remaining points.
    """
    active = np.where(local_weight > 1e-15)[0]
    if active.size == 0:
        return local_weight, np.zeros_like(local_weight), active, np.array([])

    # Order the active points from nearest to farthest.
    nearest_first = np.argsort(np.sum((z[active] - x) ** 2, axis=1))
    by_distance = active[nearest_first]
    # Running total is taken before any weight is removed below.
    running_total = np.cumsum(local_weight[by_distance])
    cutoff = np.searchsorted(running_total, alpha_k)

    consumed = np.zeros_like(local_weight)
    if cutoff >= len(by_distance):
        # Not enough weight left in total: drain every active point.
        consumed[by_distance] = local_weight[by_distance]
        local_weight[by_distance] = 0
    else:
        # Points before the cutoff are drained completely ...
        drained = by_distance[:cutoff]
        consumed[drained] = local_weight[drained]
        local_weight[drained] = 0
        # ... and the cutoff point gives up only what is still missing.
        partial = by_distance[cutoff]
        remainder = alpha_k - (running_total[cutoff - 1] if cutoff > 0 else 0)
        consumed[partial] = remainder
        local_weight[partial] -= remainder

    # Snap round-off residue (including tiny negatives) to exactly zero.
    local_weight[local_weight < 1e-12] = 0

    remaining = np.where(local_weight > 1e-15)[0]
    if len(remaining) > 0:
        remaining_dist_sq = np.sum((z[remaining] - x) ** 2, axis=1)
    else:
        remaining_dist_sq = np.array([])
    return local_weight, consumed, remaining, remaining_dist_sq
|
| 51 |
+
|
| 52 |
+
# 3. Main Simulation Logic
|
| 53 |
+
def run_simulation(num_agents, battery_life):
    """Simulate multi-agent density coverage and render it as an animated GIF.

    A reference density is sampled as 8 Gaussian clusters of 200 points each.
    Every agent starts with a uniform weight over those points and, at each
    step, consumes ``1 / battery_life`` of weight from its nearest points,
    then steers toward the centroid returned by ``hamilton_optimal_control``
    using a 5-step finite-horizon quadratic control problem. The run lasts at
    most 1000 steps; a frame is captured every 100 steps and the run stops
    early once the reported coverage exceeds 95%.

    Args:
        num_agents: Number of agents; converted with ``int()``.
        battery_life: Number of steps over which an agent would consume a
            total weight of 1; converted with ``int()``.

    Returns:
        Path of the written GIF (``"output.gif"`` in the working directory).

    Raises:
        ValueError: If ``num_agents`` or ``battery_life`` is below 1.
    """
    no_of_agents = int(num_agents)
    bat_life = int(battery_life)
    if no_of_agents < 1:
        raise ValueError("num_agents must be at least 1")
    if bat_life < 1:
        raise ValueError("battery_life must be at least 1")

    # comm_range is defined but not read anywhere in this function.
    T, g, comm_range = 0.1, 9.81, 15.0
    alpha_k = 1.0 / bat_life
    frames = []

    # Discrete-time model: two identical 4-state axes stacked into 8 states.
    # Indices 0 and 4 are the planar position and 1 and 5 the velocities used
    # below; the remaining states are only touched through Ad/Bd.
    Ad_sub = np.array([[1, T, 0, 0], [0, 0.7, -T*g, 0], [0, 0, 1, T], [0, 0, 0, 0.7]])
    Ad = np.zeros((8, 8)); Ad[:4, :4] = Ad_sub; Ad[4:, 4:] = Ad_sub
    Bd = np.zeros((8, 2)); Bd[3, 0] = T/0.1; Bd[7, 1] = T/0.1
    # Tracking cost acts on the two position states only.
    CTC = np.zeros((8, 8)); CTC[0, 0] = 1.0; CTC[4, 4] = 1.0
    Q_base, R_block = np.eye(8) * 0.1, np.eye(2) * 10.0
    hor_leng = 5

    # Constant blocks of the stacked horizon problem (5 steps x 8 states = 40),
    # built and inverted once outside the time loop.
    E12 = np.zeros((40, 40))
    for i in range(hor_leng):
        E12[i*8:(i+1)*8, i*8:(i+1)*8] = -np.eye(8)
        if i < hor_leng - 1:
            E12[i*8:(i+1)*8, (i+1)*8:(i+2)*8] = Ad.T
    E12_inv = np.linalg.inv(E12)
    E23 = np.kron(np.eye(hor_leng), Bd)
    E33 = np.kron(np.eye(hor_leng), R_block)
    E12_inv_E23 = E12_inv @ E23
    E23T_E12inv = E23.T @ E12_inv

    # Reference samples: 8 clusters with random centres in [20, 80]^2.
    y_ref = np.vstack([np.random.multivariate_normal(np.random.uniform(20, 80, 2), np.eye(2)*15, 200) for _ in range(8)])
    # Density estimate on a 50x50 grid, used only for the background contours.
    xi, yi = np.mgrid[0:100:50j, 0:100:50j]
    positions = np.vstack([xi.ravel(), yi.ravel()])
    kde = gaussian_kde(y_ref.T)
    zi = np.reshape(kde(positions).T, xi.shape)

    agent_betas = [np.ones(len(y_ref)) / len(y_ref) for _ in range(no_of_agents)]
    # Bookkeeping arrays: written every step but not read in this function.
    agent_wgt_history = np.zeros((no_of_agents, no_of_agents, len(y_ref)))
    comm_time, agents = np.zeros((no_of_agents, no_of_agents)), np.zeros((no_of_agents, 8))
    trajectories = [[] for _ in range(no_of_agents)]

    for n in range(no_of_agents):
        agents[n, [0, 4]] = np.random.uniform(10, 90, 2)

    fig, ax = plt.subplots(figsize=(6, 6))
    max_vel = 0.8

    # Close the figure even if the simulation or the GIF export fails, so a
    # long-running server does not accumulate open pyplot figures.
    try:
        for t in range(1, 1001):
            for n in range(no_of_agents):
                trajectories[n].append(agents[n, [0, 4]].copy())
                agent_betas[n], delta_w, pos_idx, d_sq = update_weight_local(
                    agents[n, [0, 4]], y_ref, agent_betas[n], alpha_k)
                agent_wgt_history[n, n] += delta_w
                comm_time[n, n] = t
                xg, gamma = hamilton_optimal_control(
                    agents[n, [0, 4]], alpha_k, y_ref, agent_betas[n], pos_idx, d_sq)

                if xg is not None and gamma > 1e-12:
                    # Tracking weight scales with the mass behind the target.
                    Q_track = (gamma * 250.0) * CTC + Q_base
                    xg_full = np.zeros(8)
                    xg_full[0], xg_full[4] = xg[0], xg[1]
                    F1 = np.tile(Q_track @ xg_full, hor_leng)
                    F2 = np.zeros(40)
                    F2[:8] = -Ad @ agents[n]
                    M = E12_inv_E23
                    # Apply the block-diagonal Q_track to each 8-row block.
                    E11_M = np.zeros_like(M)
                    for i in range(hor_leng):
                        E11_M[i*8:(i+1)*8, :] = Q_track @ M[i*8:(i+1)*8, :]
                    lhs = E33 + M.T @ E11_M
                    temp_F2_trans = E12_inv.T @ F2
                    E11_temp_F2 = np.zeros(40)
                    for i in range(hor_leng):
                        E11_temp_F2[i*8:(i+1)*8] = Q_track @ temp_F2_trans[i*8:(i+1)*8]
                    rhs = -(E23T_E12inv @ F1) + E23.T @ (E12_inv @ E11_temp_F2)
                    # Only the first step's input of the horizon is applied.
                    u_star = np.linalg.solve(lhs, rhs)[:2]
                else:
                    # No target available: command against the current velocity.
                    u_star = -agents[n, [1, 5]] / T

                agents[n] = Ad @ agents[n] + Bd @ np.clip(u_star, -1.0, 1.0)
                # Keep the agent inside the 100x100 arena.
                agents[n, [0, 4]] = np.clip(agents[n, [0, 4]], 0.1, 99.9)
                vel = agents[n, [1, 5]]
                spd = np.linalg.norm(vel)
                if spd > max_vel:
                    agents[n, [1, 5]] = (vel / spd) * max_vel

            if t % 100 == 0:
                ax.clear()
                ax.contour(xi, yi, zi, levels=10, colors='black', linewidths=0.3, alpha=0.3)
                # Per-point minimum of remaining weight across all agents.
                consensus_rem = np.min(agent_betas, axis=0)
                consumed = consensus_rem <= 1e-8
                ax.scatter(y_ref[~consumed][:, 0], y_ref[~consumed][:, 1], c='#007BFF', s=2, alpha=0.2)
                for n in range(no_of_agents):
                    path_data = np.array(trajectories[n])
                    ax.plot(path_data[:, 0], path_data[:, 1], lw=1, alpha=0.6)
                    ax.plot(agents[n, 0], agents[n, 4], 'o', ms=6)
                ax.set_xlim(0, 100); ax.set_ylim(0, 100)
                ax.set_title(f"Coverage Progress: {(1-np.sum(consensus_rem))*100:.1f}%")
                fig.canvas.draw()
                # Copy the canvas buffer before the axes are reused for the
                # next frame; buffer_rgba() is the public canvas accessor.
                image = Image.fromarray(np.array(fig.canvas.buffer_rgba())).convert("RGB")
                frames.append(image)
                if (1-np.sum(consensus_rem)) > 0.95:
                    break

        gif_path = "output.gif"
        frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=150, loop=0)
    finally:
        plt.close(fig)
    return gif_path
|
| 141 |
+
|
| 142 |
+
# 4. Gradio UI Configuration
|
| 143 |
+
with gr.Blocks() as demo:
    # Page header and one-line description of the demo.
    gr.Markdown("# 🛸 D2OC: Density-Driven Optimal Control")
    gr.Markdown("Interactive Demo for Decentralized Multi-Agent Coverage using Optimal Transport (OT) and Wasserstein Distance.")

    # Simulation parameters, laid out side by side.
    with gr.Row():
        num_agents = gr.Slider(
            label="Number of Agents",
            minimum=2,
            maximum=20,
            step=1,
            value=8,
        )
        bat_life = gr.Slider(
            label="Task Capacity (Battery Life)",
            minimum=1000,
            maximum=50000,
            step=1000,
            value=10000,
        )

    btn = gr.Button("Run Simulation")
    output_gif = gr.Image(label="Simulation Result (GIF)")

    # Clicking the button runs the simulation with the two slider values and
    # shows the returned GIF path in the image component.
    btn.click(run_simulation, inputs=[num_agents, bat_life], outputs=output_gif)

# Start the web server as soon as the script is executed.
demo.launch()
|