D2OC commited on
Commit
ba8951b
·
verified ·
1 Parent(s): a7a4de8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from PIL import Image
5
+ from scipy.stats import gaussian_kde
6
+ import os
7
+
8
# 1. Hamiltonian Optimal Control Function
def hamilton_optimal_control(x, alpha_k, z, local_weight, pos_idx, d_sq):
    """Compute the budget-weighted target centroid for one agent.

    Ranks the active reference points by weight-normalized squared distance,
    keeps the cheapest prefix whose weights sum to the budget ``alpha_k``,
    and returns that prefix's weighted centroid plus the total weight taken.

    Note: ``x`` is accepted but not read here (the distances arrive
    precomputed in ``d_sq``); it is kept for signature symmetry with
    ``update_weight_local``.

    Returns (None, 0) when ``pos_idx`` is empty.
    """
    if len(pos_idx) == 0:
        return None, 0

    # Weight-normalized transport cost for each active candidate point.
    active_w = local_weight[pos_idx]
    cost = d_sq / (active_w**2 + 1e-12)

    # Candidates ordered cheapest-first.
    order = np.argsort(cost)
    ranked_idx = pos_idx[order]
    ranked_w = local_weight[ranked_idx]
    running_w = np.cumsum(ranked_w)

    # Smallest prefix whose cumulative weight reaches the budget; clamped to
    # the full candidate set when the total available weight falls short.
    stop = np.searchsorted(running_w, alpha_k)
    if stop >= len(ranked_idx):
        stop = len(ranked_idx) - 1

    chosen_idx = ranked_idx[:stop + 1]
    chosen_w = ranked_w[:stop + 1].copy()
    # The last selected point contributes only the residual budget.
    chosen_w[-1] = alpha_k - running_w[stop - 1] if stop > 0 else alpha_k

    # Budget-weighted centroid of the selected points.
    col_w = chosen_w.reshape(-1, 1)
    gamma = np.sum(col_w)
    xg = np.sum(col_w * z[chosen_idx], axis=0) / (gamma + 1e-12)
    return xg, gamma
28
+
29
# 2. Local Weight Update Function
def update_weight_local(x, z, local_weight, alpha_k):
    """Consume up to ``alpha_k`` units of weight from the points nearest ``x``.

    Mutates ``local_weight`` in place: moving nearest-first, weight is
    transferred out of ``local_weight`` and into the returned consumption
    vector until the per-step budget is exhausted (or nothing is left).

    Returns (local_weight, weight_dist, new_pos_idx, new_dist_sq):
    the (mutated) weight vector, the weight consumed this call, and the
    indices / squared distances of points still carrying weight afterwards.
    """
    active = np.where(local_weight > 1e-15)[0]
    consumed = np.zeros_like(local_weight)
    if len(active) == 0:
        return local_weight, consumed, active, np.array([])

    # Active points ordered nearest-first relative to the agent position.
    nearest_order = np.argsort(np.sum((z[active] - x)**2, axis=1))
    by_distance = active[nearest_order]
    running = np.cumsum(local_weight[by_distance])
    k = np.searchsorted(running, alpha_k)

    if k >= len(by_distance):
        # Budget exceeds everything available: drain all active points.
        consumed[by_distance] = local_weight[by_distance]
        local_weight[by_distance] = 0
    else:
        # Fully drain the first k points, then take only the residual
        # budget from the (k+1)-th nearest point.
        full = by_distance[:k]
        consumed[full] = local_weight[full]
        local_weight[full] = 0
        residual = alpha_k - (running[k - 1] if k > 0 else 0)
        consumed[by_distance[k]] = residual
        local_weight[by_distance[k]] -= residual

    # Snap tiny numerical leftovers to exactly zero.
    local_weight[local_weight < 1e-12] = 0
    still_active = np.where(local_weight > 1e-15)[0]
    if len(still_active) > 0:
        still_dist = np.sum((z[still_active] - x)**2, axis=1)
    else:
        still_dist = np.array([])
    return local_weight, consumed, still_active, still_dist
51
+
52
# 3. Main Simulation Logic
def run_simulation(num_agents, battery_life):
    """Run the decentralized multi-agent coverage simulation, save a GIF.

    Args:
        num_agents: number of agents deployed on the 100x100 workspace
            (coerced to int).
        battery_life: task capacity; the per-step weight budget is
            alpha_k = 1 / battery_life.

    Returns:
        str: path of the rendered animation ("output.gif").
    """
    no_of_agents = int(num_agents)
    bat_life = int(battery_life)
    # T: integration step, g: gravity; comm_range is defined but never read below.
    T, g, comm_range = 0.1, 9.81, 15.0
    alpha_k = 1.0 / bat_life  # per-agent weight budget consumed each step
    frames = []  # rendered PIL frames for the output GIF

    # Per-axis 4-state linear dynamics, duplicated on the 8x8 block diagonal for
    # the x and y axes. The -T*g coupling and 0.7 damping suggest a linearized
    # tilt/thrust vehicle model — NOTE(review): confirm against the source paper.
    Ad_sub = np.array([[1, T, 0, 0], [0, 0.7, -T*g, 0], [0, 0, 1, T], [0, 0, 0, 0.7]])
    Ad = np.zeros((8, 8)); Ad[:4, :4] = Ad_sub; Ad[4:, 4:] = Ad_sub
    # Two control inputs drive states 3 and 7 (one per axis).
    Bd = np.zeros((8, 2)); Bd[3, 0] = T/0.1; Bd[7, 1] = T/0.1
    # Output selector: only the position states (indices 0 and 4) are tracked.
    CTC = np.zeros((8, 8)); CTC[0,0] = 1.0; CTC[4,4] = 1.0
    Q_base, R_block = np.eye(8) * 0.1, np.eye(2) * 10.0  # base state / input weights
    hor_leng = 5  # MPC horizon (5 steps x 8 states = 40 stacked states)
    # E12 encodes the stacked horizon dynamics structure (block-bidiagonal).
    E12 = np.zeros((40, 40))
    for i in range(hor_leng):
        E12[i*8:(i+1)*8, i*8:(i+1)*8] = -np.eye(8)
        if i < hor_leng - 1: E12[i*8:(i+1)*8, (i+1)*8:(i+2)*8] = Ad.T
    E12_inv = np.linalg.inv(E12)
    E23 = np.kron(np.eye(hor_leng), Bd)       # stacked input map over the horizon
    E33 = np.kron(np.eye(hor_leng), R_block)  # stacked input cost
    # Horizon-invariant products, hoisted out of the control loop.
    E12_inv_E23 = E12_inv @ E23
    E23T_E12inv = E23.T @ E12_inv

    # Reference density: 8 random Gaussian clusters, 200 samples each (1600 points).
    y_ref = np.vstack([np.random.multivariate_normal(np.random.uniform(20, 80, 2), np.eye(2)*15, 200) for _ in range(8)])
    # KDE of the reference samples on a 50x50 grid — used only for the contour plot.
    xi, yi = np.mgrid[0:100:50j, 0:100:50j]
    positions = np.vstack([xi.ravel(), yi.ravel()])
    kde = gaussian_kde(y_ref.T)
    zi = np.reshape(kde(positions).T, xi.shape)

    # Each agent starts with a uniform weight over all reference points (sums to 1).
    agent_betas = [np.ones(len(y_ref)) / len(y_ref) for _ in range(no_of_agents)]
    # Per-pair consumed-weight record; only the diagonal [n, n] is written below.
    agent_wgt_history = np.zeros((no_of_agents, no_of_agents, len(y_ref)))
    # comm_time: last-update timestamps (diagonal only); agents: 8-dim state per agent.
    comm_time, agents = np.zeros((no_of_agents, no_of_agents)), np.zeros((no_of_agents, 8))
    trajectories = [[] for _ in range(no_of_agents)]

    # Random initial positions, kept well inside the workspace.
    for n in range(no_of_agents):
        agents[n, [0, 4]] = np.random.uniform(10, 90, 2)

    fig, ax = plt.subplots(figsize=(6, 6))
    max_vel = 0.8  # speed cap applied after every state update

    for t in range(1, 1001):
        for n in range(no_of_agents):
            trajectories[n].append(agents[n, [0, 4]].copy())
            # Consume up to alpha_k of this agent's weight at its nearest points.
            agent_betas[n], delta_w, pos_idx, d_sq = update_weight_local(agents[n, [0, 4]], y_ref, agent_betas[n], alpha_k)
            agent_wgt_history[n, n] += delta_w
            comm_time[n, n] = t
            # Centroid of the next chunk of weight this agent should chase.
            xg, gamma = hamilton_optimal_control(agents[n, [0, 4]], alpha_k, y_ref, agent_betas[n], pos_idx, d_sq)
            if xg is not None and gamma > 1e-12:
                # Tracking weight scales with the captured mass gamma.
                Q_track = (gamma * 250.0) * CTC + Q_base
                xg_full = np.zeros(8); xg_full[0], xg_full[4] = xg[0], xg[1]
                # Linear terms of the horizon QP; F2 injects the current state.
                F1 = np.tile(Q_track @ xg_full, hor_leng); F2 = np.zeros(40); F2[:8] = -Ad @ agents[n]
                M = E12_inv_E23
                # Apply the block-diagonal Q_track to each 8-row band of M.
                E11_M = np.zeros_like(M)
                for i in range(hor_leng): E11_M[i*8:(i+1)*8, :] = Q_track @ M[i*8:(i+1)*8, :]
                lhs = E33 + M.T @ E11_M
                temp_F2_trans = E12_inv.T @ F2
                E11_temp_F2 = np.zeros(40)
                for i in range(hor_leng): E11_temp_F2[i*8:(i+1)*8] = Q_track @ temp_F2_trans[i*8:(i+1)*8]
                rhs = -(E23T_E12inv @ F1) + E23.T @ (E12_inv @ E11_temp_F2)
                # Solve the unconstrained QP; keep only the first control move.
                u_star = np.linalg.solve(lhs, rhs)[:2]
            else:
                # No weight left to chase: command braking to a stop.
                u_star = -agents[n, [1, 5]] / T
            # Propagate dynamics with the saturated input, then clamp to the workspace.
            agents[n] = Ad @ agents[n] + Bd @ np.clip(u_star, -1.0, 1.0)
            agents[n, [0, 4]] = np.clip(agents[n, [0, 4]], 0.1, 99.9)
            vel = agents[n, [1, 5]]; spd = np.linalg.norm(vel)
            if spd > max_vel: agents[n, [1, 5]] = (vel / spd) * max_vel

        # Render one frame every 100 steps.
        if t % 100 == 0:
            ax.clear()
            ax.contour(xi, yi, zi, levels=10, colors='black', linewidths=0.3, alpha=0.3)
            # Element-wise min over agents = weight no agent has consumed yet.
            consensus_rem = np.min(agent_betas, axis=0)
            consumed = consensus_rem <= 1e-8
            ax.scatter(y_ref[~consumed][:,0], y_ref[~consumed][:,1], c='#007BFF', s=2, alpha=0.2)
            for n in range(no_of_agents):
                path_data = np.array(trajectories[n])
                ax.plot(path_data[:, 0], path_data[:, 1], lw=1, alpha=0.6)
                ax.plot(agents[n, 0], agents[n, 4], 'o', ms=6)
            ax.set_xlim(0, 100); ax.set_ylim(0, 100)
            ax.set_title(f"Coverage Progress: {(1-np.sum(consensus_rem))*100:.1f}%")
            fig.canvas.draw()
            image = Image.fromarray(np.array(fig.canvas.renderer.buffer_rgba())).convert("RGB")
            frames.append(image)
            # Stop early once >95% of the total mass has been covered.
            if (1-np.sum(consensus_rem)) > 0.95: break

    gif_path = "output.gif"
    # frames is non-empty: t always reaches 100 before any break can fire.
    frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=150, loop=0)
    plt.close(fig)
    return gif_path
141
+
142
# 4. Gradio UI Configuration
with gr.Blocks() as demo:
    # Page header and a short description of the demo.
    gr.Markdown("# 🛸 D2OC: Density-Driven Optimal Control")
    gr.Markdown("Interactive Demo for Decentralized Multi-Agent Coverage using Optimal Transport (OT) and Wasserstein Distance.")
    # Simulation parameters, side by side.
    with gr.Row():
        agent_slider = gr.Slider(minimum=2, maximum=20, value=8, step=1, label="Number of Agents")
        battery_slider = gr.Slider(minimum=1000, maximum=50000, value=10000, step=1000, label="Task Capacity (Battery Life)")
    run_button = gr.Button("Run Simulation")
    result_view = gr.Image(label="Simulation Result (GIF)")
    # Launch the (long-running) simulation and display the produced GIF.
    run_button.click(fn=run_simulation, inputs=[agent_slider, battery_slider], outputs=result_view)

demo.launch()