TahaRasouli commited on
Commit
9d60b9f
·
verified ·
1 Parent(s): 4d2d459

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +745 -0
app.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import networkx as nx
3
+ import matplotlib.pyplot as plt
4
+ import random
5
+ import time
6
+ import json
7
+ import os
8
+ import shutil
9
+ import zipfile
10
+ from datetime import datetime
11
+
12
+ # ==========================================
13
+ # 1. JSON EXPORT LOGIC (Adapted from your script)
14
+ # ==========================================
15
+ def prepare_edges_for_json(G):
16
+ # Sort nodes by X, then Y to ensure consistent IDs
17
+ # G.nodes() are tuples (x, y)
18
+ nodes_list = sorted(list(G.nodes()), key=lambda l: (l[0], l[1]))
19
+
20
+ # Map (x,y) -> "1", "2", "3"...
21
+ nodes_list_dict = {}
22
+ I = []
23
+ for idx, node in enumerate(nodes_list):
24
+ s_id = str(idx + 1)
25
+ I.append(s_id)
26
+ nodes_list_dict[s_id] = node
27
+
28
+ # Create reverse map for easy lookup
29
+ coord_to_id = {v: k for k, v in nodes_list_dict.items()}
30
+
31
+ edges_list = list(G.edges())
32
+ edges_formatted = []
33
+
34
+ for u, v in edges_list:
35
+ if u in coord_to_id and v in coord_to_id:
36
+ edges_formatted.append({
37
+ "room1": coord_to_id[u],
38
+ "room2": coord_to_id[v]
39
+ })
40
+
41
+ return edges_formatted, I, nodes_list_dict
42
+
43
+ def prepare_parameter_for_json(G, I, nodes_list_dict):
44
+ # Use deterministic randomness based on graph hash or similar if needed,
45
+ # but here we use standard random as per requirements.
46
+
47
+ # 1. Weights for dismantling (prefer front rooms)
48
+ weights = []
49
+ n_count = len(G.nodes())
50
+ for i in range(n_count):
51
+ # Formula from your script
52
+ val = n_count / (n_count * (1 + (((i + 1) * 2) / 30)))
53
+ weights.append(val)
54
+
55
+ m_weights = random.choices(I, weights=weights, k=5)
56
+
57
+ # 2. Weights for time (prefer early times)
58
+ t_weights_probs = []
59
+ for i in range(10):
60
+ val = n_count / (n_count * (1 + (((i + 1) * 2) / 5)))
61
+ t_weights_probs.append(val)
62
+ t_weights = random.choices(range(1, 11), weights=t_weights_probs, k=5)
63
+
64
+ dismantled = []
65
+ conditioningDuration = []
66
+ assignment = []
67
+ help_list = []
68
+
69
+ # 3. Build Dismantled / Assignment
70
+ for m in range(5):
71
+ dismantled.append({"m": str(m + 1), "i": str(m_weights[m]), "t": t_weights[m], "value": 1})
72
+ conditioningDuration.append({"m": str(m + 1), "value": 1})
73
+
74
+ # Random device selection with fallback logic
75
+ x = random.randint(1, 3)
76
+ if m > 2:
77
+ if 1 not in help_list: x = 1
78
+ if 2 not in help_list: x = 2
79
+ if 3 not in help_list: x = 3
80
+
81
+ help_list.append(x)
82
+ assignment.append({"m": str(m + 1), "r": str(x), "value": 1})
83
+
84
+ # 4. Delivery
85
+ t_weights_del = random.choices(range(1, 11), weights=t_weights_probs[:10], k=3)
86
+ delivered = []
87
+ conditioningCapacity = []
88
+
89
+ for r in range(3):
90
+ delivered.append({"r": str(r + 1), "i": "1", "t": t_weights_del[r], "value": 1})
91
+ conditioningCapacity.append({"r": str(r + 1), "value": 1})
92
+
93
+ # 5. Costs
94
+ CostMT, CostMB, CostRT, CostRB = [], [], [], []
95
+
96
+ for i in range(n_count):
97
+ s_id = str(i + 1)
98
+ CostMT.append({"i": s_id, "value": random.choice([2, 5])})
99
+ CostMB.append({"i": s_id, "value": random.choice([5, 10, 30])})
100
+ CostRT.append({"i": s_id, "value": random.choice([4, 10])})
101
+
102
+ if i == 0:
103
+ CostRB.append({"i": s_id, "value": 1000})
104
+ else:
105
+ CostRB.append({"i": s_id, "value": random.choice([20, 30, 100])})
106
+
107
+ # 6. Coordinates
108
+ Coord = []
109
+ # nodes_list_dict maps "1" -> (x,y)
110
+ for i in range(n_count):
111
+ s_id = str(i + 1)
112
+ if s_id in nodes_list_dict:
113
+ Coord.append({"i": s_id, "Coordinates": nodes_list_dict[s_id]})
114
+
115
+ return dismantled, assignment, delivered, conditioningCapacity, conditioningDuration, CostMT, CostMB, CostRT, CostRB, Coord
116
+
117
+ def generate_full_json_dict(G, loop=0):
118
+ edges, I, nodes_list_dict = prepare_edges_for_json(G)
119
+ dismantled, assignment, delivered, condCap, condDur, CostMT, CostMB, CostRT, CostRB, Coord = prepare_parameter_for_json(G, I, nodes_list_dict)
120
+
121
+ sets = {
122
+ "I": I,
123
+ "E": {"bidirectional": True, "seed": 1, "edges": edges},
124
+ "M": ["1", "2", "3", "4", "5"],
125
+ "R": ["1", "2", "3"]
126
+ }
127
+
128
+ params = {
129
+ "defaults": {
130
+ "V": 1000, "dismantled": 0, "delivered": 0,
131
+ "conditioningCapacity": 1000, "conditioningDuration": 1,
132
+ "CostMB": 100, "CostMT": 20, "CostRB": 300, "CostRT": 50,
133
+ "assignment": 0, "dismantled_room_bound": 0,
134
+ "CostFI": 20, "CostVI": 20
135
+ },
136
+ "t_max": 100,
137
+ "V": [{"m": "1", "i": "1", "value": 42}],
138
+ "dismantled": dismantled,
139
+ "delivered": delivered,
140
+ "conditioningCapacity": condCap,
141
+ "conditioningDuration": condDur,
142
+ "assignment": assignment,
143
+ "CostMT": CostMT, "CostMB": CostMB,
144
+ "CostRT": CostRT, "CostRB": CostRB,
145
+ "CostZR": 9, "CostZH": 5,
146
+ "Coord": Coord
147
+ }
148
+
149
+ return {"description": "Generated by Gradio Network Generator", "sets": sets, "params": params}
150
+
151
+ # ==========================================
152
+ # 2. CORE LOGIC: NETWORK GENERATOR CLASS
153
+ # ==========================================
154
+ class NetworkGenerator:
155
+ def __init__(self, width=10, height=10, variant="F", topology="highly_connected",
156
+ node_drop_fraction=0.1, target_nodes=0, target_edges=0,
157
+ bottleneck_cluster_count=None, bottleneck_edges_per_link=1):
158
+
159
+ self.variant = variant.upper()
160
+ self.topology = topology.lower()
161
+ self.width = int(width)
162
+ self.height = int(height)
163
+ self.node_drop_fraction = float(node_drop_fraction)
164
+
165
+ # New Target Controls
166
+ self.target_nodes = int(target_nodes)
167
+ self.target_edges = int(target_edges)
168
+
169
+ self.node_factor = 0.4
170
+ if bottleneck_cluster_count is None:
171
+ area = self.width * self.height
172
+ self.bottleneck_cluster_count = max(2, int(area / 18))
173
+ else:
174
+ self.bottleneck_cluster_count = int(bottleneck_cluster_count)
175
+
176
+ self.bottleneck_edges_per_link = int(bottleneck_edges_per_link)
177
+ self.graph = None
178
+ self.active_positions = None
179
+
180
+ def generate(self):
181
+ max_attempts = 15
182
+ for attempt in range(max_attempts):
183
+ self._build_node_mask()
184
+ self._initialize_graph()
185
+ self._add_nodes() # Handles target_nodes logic inside
186
+
187
+ nodes = list(self.graph.nodes())
188
+ if len(nodes) < 2: continue
189
+
190
+ # Topology Build
191
+ if self.topology == "bottlenecks":
192
+ self._build_bottleneck_clusters(nodes)
193
+ else:
194
+ self._connect_all_nodes_by_nearby_growth(nodes)
195
+ self._add_edges()
196
+
197
+ # Cleanup
198
+ self._remove_intersections()
199
+
200
+ # Post-Processing for Target Edges
201
+ if self.target_edges > 0:
202
+ self._adjust_edges_to_target()
203
+ else:
204
+ self._enforce_edge_budget()
205
+
206
+ if not nx.is_connected(self.graph):
207
+ self._force_connect_components()
208
+
209
+ self._remove_intersections()
210
+
211
+ # Final check: if target nodes was set, did we hit it?
212
+ # (We might miss slightly due to connectivity/intersection constraints, but we accept best effort)
213
+
214
+ if nx.is_connected(self.graph):
215
+ return self.graph
216
+
217
+ raise RuntimeError("Failed to generate valid network. Relax constraints.")
218
+
219
+ def _effective_node_drop_fraction(self):
220
+ # If target nodes is set, drop fraction is ignored/calculated dynamically
221
+ if self.target_nodes > 0: return 0.0
222
+
223
+ base = self.node_drop_fraction
224
+ if self.topology == "highly_connected": return max(0.0, base * 0.8)
225
+ if self.topology == "linear": return min(0.95, base * 1.2)
226
+ return base
227
+
228
+ def _build_node_mask(self):
229
+ all_positions = [(x, y) for x in range(self.width + 1) for y in range(self.height + 1)]
230
+
231
+ if self.target_nodes > 0:
232
+ # If explicit count requested, we don't drop randomly yet.
233
+ # We treat all as potentially active, _add_nodes will sample.
234
+ self.active_positions = set(all_positions)
235
+ else:
236
+ drop_frac = self._effective_node_drop_fraction()
237
+ drop = int(drop_frac * len(all_positions))
238
+ deactivated = set(random.sample(all_positions, drop)) if drop > 0 else set()
239
+ self.active_positions = set(all_positions) - deactivated
240
+
241
+ def _initialize_graph(self):
242
+ self.graph = nx.Graph()
243
+ margin_x = max(1, self.width // 4)
244
+ margin_y = max(1, self.height // 4)
245
+ low_x, high_x = margin_x, self.width - margin_x
246
+ low_y, high_y = margin_y, self.height - margin_y
247
+
248
+ # Prefer middle
249
+ middle_active = [p for p in self.active_positions if low_x <= p[0] <= high_x and low_y <= p[1] <= high_y]
250
+
251
+ if middle_active: seed = random.choice(middle_active)
252
+ elif self.active_positions: seed = random.choice(list(self.active_positions))
253
+ else: return
254
+ self.graph.add_node(tuple(seed))
255
+
256
+ def _add_nodes(self):
257
+ # Logic 1: Strict Target Count
258
+ if self.target_nodes > 0:
259
+ needed = self.target_nodes - len(self.graph.nodes())
260
+ if needed <= 0: return
261
+
262
+ available = list(self.active_positions - set(self.graph.nodes()))
263
+ if len(available) < needed:
264
+ # Take all
265
+ for n in available: self.graph.add_node(n)
266
+ else:
267
+ # Sample exact amount
268
+ chosen = random.sample(available, needed)
269
+ for n in chosen: self.graph.add_node(n)
270
+ return
271
+
272
+ # Logic 2: Standard Density-based
273
+ total_possible = (self.width + 1) * (self.height + 1)
274
+ base = self.node_factor if self.variant == "F" else random.uniform(0.3, 0.6)
275
+ scale = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}.get(self.topology, 1.0)
276
+ target = int(base * scale * total_possible)
277
+ target = min(target, len(self.active_positions))
278
+
279
+ attempts = 0
280
+ while len(self.graph.nodes()) < target and attempts < (target * 20):
281
+ attempts += 1
282
+ x = random.randint(0, self.width)
283
+ y = random.randint(0, self.height)
284
+ if (x, y) in self.active_positions and (x, y) not in self.graph:
285
+ self.graph.add_node((x, y))
286
+
287
+ def _connect_all_nodes_by_nearby_growth(self, nodes):
288
+ connected = set()
289
+ remaining = set(nodes)
290
+ if not remaining: return
291
+ current = random.choice(nodes)
292
+ connected.add(current)
293
+ remaining.remove(current)
294
+
295
+ while remaining:
296
+ candidates = []
297
+ # Optimization: Check nearby only
298
+ for n in remaining:
299
+ # Heuristic: check if any connected is close
300
+ # Full scan is slow for large N, but necessary for correctness
301
+ closest_dist = min([abs(n[0]-c[0]) + abs(n[1]-c[1]) for c in connected])
302
+ if closest_dist <= 4: # Manhattan dist check
303
+ candidates.append(n)
304
+
305
+ if not candidates:
306
+ # Fallback: connect closest pair globally
307
+ best_n = min(remaining, key=lambda r: min(abs(r[0]-c[0]) + abs(r[1]-c[1]) for c in connected))
308
+ candidates.append(best_n)
309
+
310
+ candidate = random.choice(candidates)
311
+
312
+ # Find closest connected node
313
+ neighbors = sorted(list(connected), key=lambda c: abs(c[0]-candidate[0]) + abs(c[1]-candidate[1]))
314
+ # Try to connect to closest 3
315
+ for n in neighbors[:3]:
316
+ if not self._would_create_intersection(n, candidate):
317
+ self.graph.add_edge(n, candidate)
318
+ break
319
+ else:
320
+ # Force connect closest if no non-intersecting found (will be cleaned later)
321
+ self.graph.add_edge(neighbors[0], candidate)
322
+
323
+ connected.add(candidate)
324
+ remaining.remove(candidate)
325
+
326
+ def _compute_edge_count(self):
327
+ if self.target_edges > 0: return self.target_edges
328
+ n = len(self.graph.nodes())
329
+ if self.topology == "highly_connected": return int(3.5 * n)
330
+ if self.topology == "bottlenecks": return int(1.8 * n)
331
+ return int(random.uniform(1.2, 2.0) * n)
332
+
333
+ def _add_edges(self):
334
+ nodes = list(self.graph.nodes())
335
+ if self.topology == "highly_connected": self._add_cluster_dense(nodes, self._compute_edge_count())
336
+ elif self.topology == "linear": self._make_linear(nodes)
337
+
338
+ def _make_linear(self, nodes):
339
+ nodes_sorted = sorted(nodes, key=lambda x: (x[0], x[1]))
340
+ if not nodes_sorted: return
341
+ prev = nodes_sorted[0]
342
+ for nxt in nodes_sorted[1:]:
343
+ if not self._would_create_intersection(prev, nxt): self.graph.add_edge(prev, nxt)
344
+ prev = nxt
345
+
346
+ def _add_cluster_dense(self, nodes, max_edges):
347
+ edges_added = 0
348
+ nodes = list(nodes)
349
+ random.shuffle(nodes)
350
+
351
+ # If target edges set, we might need a lot, so loosen distance
352
+ dist_limit = 10 if self.target_edges > 0 else 4
353
+
354
+ for i in range(len(nodes)):
355
+ for j in range(i + 1, len(nodes)):
356
+ if self.target_edges == 0 and edges_added >= max_edges: return
357
+ n1, n2 = nodes[i], nodes[j]
358
+ dist = max(abs(n1[0]-n2[0]), abs(n1[1]-n2[1]))
359
+ if dist <= dist_limit:
360
+ if not self._would_create_intersection(n1, n2):
361
+ self.graph.add_edge(n1, n2)
362
+ edges_added += 1
363
+
364
+ def _build_bottleneck_clusters(self, nodes):
365
+ self.graph.remove_edges_from(list(self.graph.edges()))
366
+ clusters, centers = self._spatial_cluster_nodes(nodes, k=self.bottleneck_cluster_count)
367
+ for cluster in clusters:
368
+ if len(cluster) < 2: continue
369
+ self._connect_cluster_by_nearby_growth(cluster)
370
+ self._add_cluster_dense(list(cluster), max_edges=max(1, int(3.5 * len(cluster))))
371
+ order = sorted(range(len(clusters)), key=lambda i: (centers[i][0], centers[i][1]))
372
+ for a_idx, b_idx in zip(order[:-1], order[1:]):
373
+ self._add_bottleneck_links(clusters[a_idx], clusters[b_idx], self.bottleneck_edges_per_link)
374
+ if not nx.is_connected(self.graph): self._force_connect_components()
375
+
376
+ def _force_connect_components(self):
377
+ components = list(nx.connected_components(self.graph))
378
+ while len(components) > 1:
379
+ c1, c2 = list(components[0]), list(components[1])
380
+ best_pair, min_dist = None, float('inf')
381
+
382
+ # Sample for speed if huge
383
+ s1 = c1 if len(c1)<30 else random.sample(c1, 30)
384
+ s2 = c2 if len(c2)<30 else random.sample(c2, 30)
385
+
386
+ for u in s1:
387
+ for v in s2:
388
+ d = (u[0]-v[0])**2 + (u[1]-v[1])**2
389
+ if d < min_dist and not self._would_create_intersection(u, v):
390
+ min_dist, best_pair = d, (u, v)
391
+
392
+ if best_pair: self.graph.add_edge(best_pair[0], best_pair[1])
393
+ else: break # Cannot connect cleanly
394
+
395
+ prev_len = len(components)
396
+ components = list(nx.connected_components(self.graph))
397
+ if len(components) == prev_len: break
398
+
399
+ def _spatial_cluster_nodes(self, nodes, k):
400
+ nodes = list(nodes)
401
+ if k >= len(nodes): return [[n] for n in nodes], nodes[:]
402
+ centers = random.sample(nodes, k)
403
+ clusters = [[] for _ in range(k)]
404
+ for n in nodes:
405
+ best_i = min(range(k), key=lambda i: max(abs(n[0]-centers[i][0]), abs(n[1]-centers[i][1])))
406
+ clusters[best_i].append(n)
407
+ return clusters, centers
408
+
409
+ def _connect_cluster_by_nearby_growth(self, cluster_nodes): self._connect_all_nodes_by_nearby_growth(cluster_nodes)
410
+
411
+ def _add_bottleneck_links(self, cluster_a, cluster_b, m):
412
+ pairs = []
413
+ for u in cluster_a:
414
+ for v in cluster_b:
415
+ dist = max(abs(u[0]-v[0]), abs(u[1]-v[1]))
416
+ pairs.append((dist, u, v))
417
+ pairs.sort(key=lambda t: t[0])
418
+ added = 0
419
+ for _, u, v in pairs:
420
+ if added >= m: break
421
+ if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
422
+ self.graph.add_edge(u, v)
423
+ added += 1
424
+
425
+ def _remove_intersections(self):
426
+ # Full check
427
+ pass_no = 0
428
+ while pass_no < 5:
429
+ pass_no += 1
430
+ edges = list(self.graph.edges())
431
+ intersections = []
432
+ # Check subset if massive
433
+ if len(edges) > 600:
434
+ check_edges = random.sample(edges, 400)
435
+ else:
436
+ check_edges = edges
437
+
438
+ for i in range(len(check_edges)):
439
+ for j in range(i+1, len(check_edges)):
440
+ e1, e2 = check_edges[i], check_edges[j]
441
+ if self._segments_intersect(e1[0], e1[1], e2[0], e2[1]): intersections.append((e1, e2))
442
+
443
+ if not intersections: break
444
+ for e1, e2 in intersections:
445
+ if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2): continue
446
+ l1 = (e1[0][0]-e1[1][0])**2 + (e1[0][1]-e1[1][1])**2
447
+ l2 = (e2[0][0]-e2[1][0])**2 + (e2[0][1]-e2[1][1])**2
448
+ self.graph.remove_edge(e1 if l1 > l2 else e2)
449
+
450
+ def _adjust_edges_to_target(self):
451
+ # If target_edges is set, we strictly add or remove
452
+ current_edges = list(self.graph.edges())
453
+ curr_count = len(current_edges)
454
+
455
+ # Case 1: Too many
456
+ if curr_count > self.target_edges:
457
+ to_remove = curr_count - self.target_edges
458
+ # remove longest first
459
+ sorted_edges = sorted(current_edges, key=lambda e: (e[0][0]-e[1][0])**2 + (e[0][1]-e[1][1])**2, reverse=True)
460
+ for e in sorted_edges:
461
+ if len(self.graph.edges()) <= self.target_edges: break
462
+ self.graph.remove_edge(*e)
463
+ if not nx.is_connected(self.graph):
464
+ self.graph.add_edge(*e) # Put it back if it breaks connectivity
465
+
466
+ # Case 2: Too few
467
+ elif curr_count < self.target_edges:
468
+ needed = self.target_edges - curr_count
469
+ nodes = list(self.graph.nodes())
470
+ attempts = 0
471
+ while len(self.graph.edges()) < self.target_edges and attempts < (needed * 20):
472
+ attempts += 1
473
+ u, v = random.sample(nodes, 2)
474
+ if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
475
+ # Check distance sanity (don't connect across map randomly unless desperate)
476
+ dist = abs(u[0]-v[0]) + abs(u[1]-v[1])
477
+ if dist < max(self.width, self.height) / 2:
478
+ self.graph.add_edge(u, v)
479
+
480
+ def _enforce_edge_budget(self):
481
+ budget = self._compute_edge_count()
482
+ while len(self.graph.edges()) > budget:
483
+ edges = list(self.graph.edges())
484
+ rem = random.choice(edges)
485
+ self.graph.remove_edge(*rem)
486
+ if not nx.is_connected(self.graph):
487
+ self.graph.add_edge(*rem)
488
+ break
489
+
490
+ def _segments_intersect(self, a, b, c, d):
491
+ if a == c or a == d or b == c or b == d: return False
492
+ def ccw(A,B,C): return (C[1]-A[1]) * (B[0]-A[0]) > (B[1]-A[1]) * (C[0]-A[0])
493
+ return ccw(a,c,d) != ccw(b,c,d) and ccw(a,b,c) != ccw(a,b,d)
494
+
495
+ def _would_create_intersection(self, u, v):
496
+ for a, b in self.graph.edges():
497
+ if u == a or u == b or v == a or v == b: continue
498
+ if self._segments_intersect(u, v, a, b): return True
499
+ return False
500
+
501
+ # === MANUAL EDITING METHODS ===
502
+ def manual_add_node(self, x, y):
503
+ # 1. Check bounds
504
+ if not (0 <= x <= self.width and 0 <= y <= self.height):
505
+ return False, "Coordinates out of bounds."
506
+ # 2. Check existence
507
+ if self.graph.has_node((x, y)):
508
+ return False, "Node already exists."
509
+
510
+ self.graph.add_node((x, y))
511
+ # Connect to nearest neighbor to maintain connectivity
512
+ nodes = list(self.graph.nodes())
513
+ if len(nodes) > 1:
514
+ # find closest
515
+ closest = min([n for n in nodes if n != (x,y)],
516
+ key=lambda n: (n[0]-x)**2 + (n[1]-y)**2)
517
+ if not self._would_create_intersection((x,y), closest):
518
+ self.graph.add_edge((x,y), closest)
519
+
520
+ return True, "Node added."
521
+
522
+ def manual_delete_node(self, x, y):
523
+ if not self.graph.has_node((x, y)):
524
+ return False, "Node does not exist."
525
+
526
+ self.graph.remove_node((x, y))
527
+
528
+ # Check connectivity
529
+ if len(self.graph.nodes()) > 1 and not nx.is_connected(self.graph):
530
+ # Try to repair? Or just warn?
531
+ # For manual edits, we usually allow disjoint temporarily,
532
+ # but let's try to reconnect components
533
+ self._force_connect_components()
534
+
535
+ return True, "Node removed."
536
+
537
+
538
+ # ==========================================
539
+ # GRADIO HELPERS
540
+ # ==========================================
541
+
542
+ def plot_graph(graph, width, height, title="Network"):
543
+ fig, ax = plt.subplots(figsize=(8, 8))
544
+ pos = {node: (node[0], node[1]) for node in graph.nodes()}
545
+
546
+ nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")
547
+ nx.draw_networkx_nodes(graph, pos, ax=ax, node_size=350, node_color="#4F46E5", edgecolors="white", linewidths=1.5)
548
+
549
+ # Label mapping (coord -> id) to match JSON output
550
+ # Sort by X, Y
551
+ sorted_nodes = sorted(list(graph.nodes()), key=lambda l: (l[0], l[1]))
552
+ labels = {node: str(i+1) for i, node in enumerate(sorted_nodes)}
553
+
554
+ nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=8, font_color="white", font_weight="bold")
555
+
556
+ ax.set_xlim(-1, width + 1)
557
+ ax.set_ylim(-1, height + 1)
558
+ ax.invert_yaxis()
559
+ ax.grid(True, linestyle=':', alpha=0.3)
560
+ ax.set_axis_on()
561
+ ax.tick_params(left=True, bottom=True, labelleft=False, labelbottom=False)
562
+ ax.set_title(title)
563
+ return fig
564
+
565
+ def get_preset_dims(preset_mode, topology):
566
+ if preset_mode == "Custom":
567
+ return gr.update(interactive=True), gr.update(interactive=True)
568
+ if topology == "linear":
569
+ dims = (4, 4) if preset_mode == "Small" else (6, 11) if preset_mode == "Medium" else (10, 26)
570
+ else:
571
+ dims = (4, 4) if preset_mode == "Small" else (8, 8) if preset_mode == "Medium" else (16, 16)
572
+ return gr.update(value=dims[0], interactive=False), gr.update(value=dims[1], interactive=False)
573
+
574
+ def update_void_settings(variant, width, height):
575
+ if variant == "Custom": return gr.update(interactive=True)
576
+ area = width * height
577
+ val = 0.60 if area <= 20 else 0.35
578
+ return gr.update(value=val, interactive=False)
579
+
580
+ # STATE HANDLER
581
+ def generate_and_store(topology, width, height, variant, void_frac, t_nodes, t_edges):
582
+ try:
583
+ var_code = "F" if variant == "Fixed" else "R"
584
+ gen = NetworkGenerator(width, height, var_code, topology, void_frac, t_nodes, t_edges)
585
+ graph = gen.generate()
586
+
587
+ fig = plot_graph(graph, width, height, f"{topology} ({len(graph.nodes())}N, {len(graph.edges())}E)")
588
+
589
+ metrics = f"**Nodes:** {len(graph.nodes())} | **Edges:** {len(graph.edges())} | **Density:** {nx.density(graph):.2f}"
590
+
591
+ # Store graph and params in state
592
+ state_data = {
593
+ "graph": graph,
594
+ "width": width,
595
+ "height": height,
596
+ "topology": topology
597
+ }
598
+ return fig, metrics, state_data, gr.update(interactive=True) # Enable edit/save
599
+ except Exception as e:
600
+ return None, f"Error: {e}", None, gr.update(interactive=False)
601
+
602
+ def manual_edit_action(action, x, y, state_data):
603
+ if not state_data or "graph" not in state_data:
604
+ return None, "No graph generated yet.", state_data
605
+
606
+ graph = state_data["graph"]
607
+ width = state_data["width"]
608
+ height = state_data["height"]
609
+
610
+ # We need a wrapper to call manual methods
611
+ # (Since methods are on class instance, but we only stored graph.
612
+ # We can briefly instantiate class or just modify graph directly using class logic methods)
613
+ # Re-instantiating is safer for accessing helper methods
614
+
615
+ gen = NetworkGenerator(width, height)
616
+ gen.graph = graph
617
+
618
+ if action == "Add Node":
619
+ success, msg = gen.manual_add_node(int(x), int(y))
620
+ else:
621
+ success, msg = gen.manual_delete_node(int(x), int(y))
622
+
623
+ if success:
624
+ fig = plot_graph(gen.graph, width, height, f"Edited ({len(gen.graph.nodes())}N)")
625
+ metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
626
+ state_data["graph"] = gen.graph # Update state
627
+ return fig, metrics, state_data
628
+ else:
629
+ return gr.update(), f"Error: {msg}", state_data
630
+
631
+ def batch_save_action(count, state_data):
632
+ if not state_data: return None
633
+
634
+ # We reconstruct the parameters from the state (or UI inputs if passed)
635
+ # For batch, users usually want variations of the CURRENT settings.
636
+ # However, Gradio state only stored the Graph. We need the original params.
637
+ # To keep it simple: We will just re-generate purely new graphs based on CURRENT UI inputs inside the loop.
638
+ return None # Placeholder, logic moved to UI event for access to inputs
639
+
640
+ def run_batch_generation(count, topology, width, height, variant, void_frac, t_nodes, t_edges):
641
+ # Temp dir
642
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
643
+ dir_name = f"batch_output_{timestamp}"
644
+ os.makedirs(dir_name, exist_ok=True)
645
+
646
+ var_code = "F" if variant == "Fixed" else "R"
647
+
648
+ try:
649
+ for i in range(int(count)):
650
+ gen = NetworkGenerator(width, height, var_code, topology, void_frac, t_nodes, t_edges)
651
+ G = gen.generate()
652
+
653
+ # Create JSON content
654
+ json_content = generate_full_json_dict(G, loop=i+1)
655
+
656
+ fname = f"instance_{timestamp}_{i+1}.json"
657
+ with open(os.path.join(dir_name, fname), 'w') as f:
658
+ json.dump(json_content, f, indent=4)
659
+
660
+ # Zip it
661
+ zip_filename = f"{dir_name}.zip"
662
+ shutil.make_archive(dir_name, 'zip', dir_name)
663
+
664
+ # Cleanup temp folder
665
+ shutil.rmtree(dir_name)
666
+
667
+ return f"{dir_name}.zip"
668
+ except Exception as e:
669
+ return None
670
+
671
+ # ==========================================
672
+ # GRADIO UI LAYOUT
673
+ # ==========================================
674
+ with gr.Blocks(title="Graph Generator Pro") as demo:
675
+ state = gr.State() # Stores current graph object
676
+
677
+ gr.Markdown("# Spatial Network Generator Pro")
678
+
679
+ with gr.Row():
680
+ # LEFT: Settings
681
+ with gr.Column(scale=1):
682
+ with gr.Tab("Config"):
683
+ topology = gr.Dropdown(["highly_connected", "bottlenecks", "linear"], value="highly_connected", label="Topology")
684
+ preset = gr.Radio(["Small", "Medium", "Large", "Custom"], value="Medium", label="Preset")
685
+
686
+ with gr.Row():
687
+ width = gr.Number(8, label="Width", precision=0, interactive=False)
688
+ height = gr.Number(8, label="Height", precision=0, interactive=False)
689
+
690
+ variant = gr.Dropdown(["Fixed", "Custom"], value="Fixed", label="Variant")
691
+ void_frac = gr.Slider(0.0, 0.9, 0.35, step=0.05, label="Void Fraction", interactive=False)
692
+
693
+ gr.Markdown("### Overrides (Optional)")
694
+ t_nodes = gr.Number(0, label="Target Node Count (0=Auto)", precision=0)
695
+ t_edges = gr.Number(0, label="Target Edge Count (0=Auto)", precision=0)
696
+
697
+ gen_btn = gr.Button("Generate Network", variant="primary")
698
+
699
+ with gr.Tab("Editor"):
700
+ gr.Markdown("Modify the current graph manually.")
701
+ with gr.Row():
702
+ ed_x = gr.Number(0, label="X", precision=0)
703
+ ed_y = gr.Number(0, label="Y", precision=0)
704
+
705
+ with gr.Row():
706
+ btn_add = gr.Button("Add Node")
707
+ btn_del = gr.Button("Delete Node")
708
+
709
+ with gr.Tab("Batch Export"):
710
+ gr.Markdown("Generate multiple variations and save as JSONs.")
711
+ batch_count = gr.Slider(1, 50, 5, step=1, label="Number of Variations")
712
+ batch_btn = gr.Button("Generate & Download Batch ZIP")
713
+ file_out = gr.File(label="Download")
714
+
715
+ # RIGHT: Visualization
716
+ with gr.Column(scale=2):
717
+ metrics = gr.Markdown("Ready.")
718
+ plot = gr.Plot()
719
+
720
+ # EVENTS
721
+
722
+ # 1. Preset & Void Logic
723
+ inputs_dims = [preset, topology]
724
+ preset.change(get_preset_dims, inputs_dims, [width, height])
725
+ topology.change(get_preset_dims, inputs_dims, [width, height])
726
+
727
+ inputs_void = [variant, width, height]
728
+ variant.change(update_void_settings, inputs_void, [void_frac])
729
+ width.change(update_void_settings, inputs_void, [void_frac])
730
+ height.change(update_void_settings, inputs_void, [void_frac])
731
+
732
+ # 2. Generation
733
+ gen_args = [topology, width, height, variant, void_frac, t_nodes, t_edges]
734
+ gen_btn.click(generate_and_store, gen_args, [plot, metrics, state])
735
+
736
+ # 3. Manual Editing
737
+ btn_add.click(manual_edit_action, [gr.State("Add Node"), ed_x, ed_y, state], [plot, metrics, state])
738
+ btn_del.click(manual_edit_action, [gr.State("Del Node"), ed_x, ed_y, state], [plot, metrics, state])
739
+
740
+ # 4. Batch
741
+ batch_args = [batch_count, topology, width, height, variant, void_frac, t_nodes, t_edges]
742
+ batch_btn.click(run_batch_generation, batch_args, [file_out])
743
+
744
+ if __name__ == "__main__":
745
+ demo.launch()