TahaRasouli commited on
Commit
c5b0086
·
verified ·
1 Parent(s): cc9f775

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -36
app.py CHANGED
@@ -140,12 +140,9 @@ class NetworkGenerator:
140
  self.active_positions = None
141
 
142
  def calculate_defaults(self):
143
- """Helper to return what the defaults WOULD be for this config."""
144
- # Nodes
145
  total_possible = (self.width + 1) * (self.height + 1)
146
  scale = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}.get(self.topology, 1.0)
147
 
148
- # Effective void fraction logic
149
  if self.topology == "highly_connected": vf = max(0.0, self.node_drop_fraction * 0.8)
150
  elif self.topology == "linear": vf = min(0.95, self.node_drop_fraction * 1.2)
151
  else: vf = self.node_drop_fraction
@@ -153,11 +150,9 @@ class NetworkGenerator:
153
  active_pct = 1.0 - vf
154
  est_nodes = int(self.node_factor * scale * total_possible * active_pct)
155
 
156
- # Edges
157
  if self.topology == "highly_connected": est_edges = int(3.5 * est_nodes)
158
  elif self.topology == "bottlenecks": est_edges = int(1.8 * est_nodes)
159
  else: est_edges = int(1.5 * est_nodes)
160
-
161
  return est_nodes, est_edges
162
 
163
  def generate(self):
@@ -224,21 +219,13 @@ class NetworkGenerator:
224
  self.graph.add_node(tuple(seed))
225
 
226
  def _add_nodes(self):
227
- # Improved: Even with target count, try to respect topology distribution
228
  if self.target_nodes > 0:
229
  needed = self.target_nodes - len(self.graph.nodes())
230
  if needed <= 0: return
231
-
232
  available = list(self.active_positions - set(self.graph.nodes()))
233
-
234
- # If bottleneck/highly connected, prefer clustering vs random scatter
235
  if self.topology != "linear" and len(available) > needed:
236
- # Pick a random center and grow out
237
  center = random.choice(list(self.graph.nodes()))
238
  available.sort(key=lambda n: (n[0]-center[0])**2 + (n[1]-center[1])**2)
239
-
240
- # Take closest needed? No, that makes one big blob.
241
- # Take random subset of available to avoid lines?
242
  chosen = random.sample(available, needed)
243
  for n in chosen: self.graph.add_node(n)
244
  else:
@@ -401,40 +388,29 @@ class NetworkGenerator:
401
  if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2): continue
402
  l1 = (e1[0][0]-e1[1][0])**2 + (e1[0][1]-e1[1][1])**2
403
  l2 = (e2[0][0]-e2[1][0])**2 + (e2[0][1]-e2[1][1])**2
404
- # FIX: remove_edge takes u,v. If rem is tuple, use *rem
405
  rem = e1 if l1 > l2 else e2
406
  self.graph.remove_edge(*rem)
407
 
408
  def _adjust_edges_to_target(self):
409
  current_edges = list(self.graph.edges())
410
  curr_count = len(current_edges)
411
-
412
  if curr_count > self.target_edges:
413
  to_remove = curr_count - self.target_edges
414
- # Remove edges that are longest first, but preserve "cluster" integrity if possible
415
  sorted_edges = sorted(current_edges, key=lambda e: (e[0][0]-e[1][0])**2 + (e[0][1]-e[1][1])**2, reverse=True)
416
  for e in sorted_edges:
417
  if len(self.graph.edges()) <= self.target_edges: break
418
  self.graph.remove_edge(*e)
419
  if not nx.is_connected(self.graph): self.graph.add_edge(*e)
420
-
421
  elif curr_count < self.target_edges:
422
  needed = self.target_edges - curr_count
423
  nodes = list(self.graph.nodes())
424
  attempts = 0
425
-
426
- # IMPROVED STRATEGY FOR ADDING EDGES TO TARGET
427
- # Instead of random (u,v), preferentially pick u,v that are close
428
- # or in same cluster to avoid "linear" lines across map
429
  while len(self.graph.edges()) < self.target_edges and attempts < (needed * 30):
430
  attempts += 1
431
  u = random.choice(nodes)
432
- # Pick v close to u
433
  candidates = sorted(nodes, key=lambda n: (n[0]-u[0])**2 + (n[1]-u[1])**2)
434
- # Skip self, pick from nearest 10
435
  if len(candidates) < 2: continue
436
  v = random.choice(candidates[1:min(len(candidates), 10)])
437
-
438
  if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
439
  self.graph.add_edge(u, v)
440
 
@@ -461,6 +437,9 @@ class NetworkGenerator:
461
 
462
  # === MANUAL EDITING ===
463
  def manual_add_node(self, x, y):
 
 
 
464
  if not (0 <= x <= self.width and 0 <= y <= self.height): return False, "Out of bounds."
465
  if self.graph.has_node((x, y)): return False, "Already exists."
466
  self.graph.add_node((x, y))
@@ -487,14 +466,26 @@ class NetworkGenerator:
487
  # ==========================================
488
  # GRADIO HELPERS
489
  # ==========================================
490
- def plot_graph(graph, width, height, title="Network"):
491
  fig, ax = plt.subplots(figsize=(8, 8))
492
  pos = {node: (node[0], node[1]) for node in graph.nodes()}
 
 
493
  nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")
494
- nx.draw_networkx_nodes(graph, pos, ax=ax, node_size=350, node_color="#4F46E5", edgecolors="white", linewidths=1.5)
 
 
 
 
 
 
 
 
 
495
  sorted_nodes = get_sorted_nodes(graph)
496
  labels = {node: str(i+1) for i, node in enumerate(sorted_nodes)}
497
  nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=8, font_color="white", font_weight="bold")
 
498
  ax.set_xlim(-1, width + 1)
499
  ax.set_ylim(-1, height + 1)
500
  ax.invert_yaxis()
@@ -513,26 +504,19 @@ def get_preset_dims(preset_mode, topology):
513
  return gr.update(value=dims[0], interactive=False), gr.update(value=dims[1], interactive=False)
514
 
515
  def update_ui_for_variant(variant, width, height, topology, void_frac):
516
- """Handles unlocking UI and Pre-filling Defaults."""
517
  is_custom = (variant == "Custom")
518
-
519
  if is_custom:
520
- # Calculate what the defaults WOULD be for this config
521
- # so user starts editing from a reasonable number, not 0
522
  temp_gen = NetworkGenerator(width, height, "F", topology, void_frac)
523
  def_nodes, def_edges = temp_gen.calculate_defaults()
524
-
525
  void_update = gr.update(interactive=True)
526
  target_node_update = gr.update(value=def_nodes, interactive=True)
527
  target_edge_update = gr.update(value=def_edges, interactive=True)
528
  else:
529
- # Fixed Mode
530
  area = width * height
531
  val = 0.60 if area <= 20 else 0.35
532
  void_update = gr.update(value=val, interactive=False)
533
  target_node_update = gr.update(value=0, interactive=False)
534
  target_edge_update = gr.update(value=0, interactive=False)
535
-
536
  return void_update, target_node_update, target_edge_update
537
 
538
  def save_single_visual_action(state_data):
@@ -564,12 +548,20 @@ def manual_edit_action(action, x, y, node_id, state_data):
564
  if not state_data or "graph" not in state_data: return None, "No graph.", state_data
565
  gen = NetworkGenerator(state_data["width"], state_data["height"])
566
  gen.graph = state_data["graph"]
 
 
 
 
567
  if action == "Add Node":
568
- success, msg = gen.manual_add_node(int(x), int(y))
 
 
 
569
  else:
570
  success, msg = gen.manual_delete_node_by_id(node_id)
 
571
  if success:
572
- fig = plot_graph(gen.graph, state_data["width"], state_data["height"], "Edited")
573
  metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
574
  state_data["graph"] = gen.graph
575
  return fig, metrics, state_data
@@ -643,7 +635,6 @@ with gr.Blocks(title="Graph Generator Pro") as demo:
643
  preset.change(get_preset_dims, inputs_dims, [width, height])
644
  topology.change(get_preset_dims, inputs_dims, [width, height])
645
 
646
- # Updated: Now includes topology for default calculation
647
  inputs_var = [variant, width, height, topology, void_frac]
648
  variant.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
649
  width.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
 
140
  self.active_positions = None
141
 
142
  def calculate_defaults(self):
 
 
143
  total_possible = (self.width + 1) * (self.height + 1)
144
  scale = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}.get(self.topology, 1.0)
145
 
 
146
  if self.topology == "highly_connected": vf = max(0.0, self.node_drop_fraction * 0.8)
147
  elif self.topology == "linear": vf = min(0.95, self.node_drop_fraction * 1.2)
148
  else: vf = self.node_drop_fraction
 
150
  active_pct = 1.0 - vf
151
  est_nodes = int(self.node_factor * scale * total_possible * active_pct)
152
 
 
153
  if self.topology == "highly_connected": est_edges = int(3.5 * est_nodes)
154
  elif self.topology == "bottlenecks": est_edges = int(1.8 * est_nodes)
155
  else: est_edges = int(1.5 * est_nodes)
 
156
  return est_nodes, est_edges
157
 
158
  def generate(self):
 
219
  self.graph.add_node(tuple(seed))
220
 
221
  def _add_nodes(self):
 
222
  if self.target_nodes > 0:
223
  needed = self.target_nodes - len(self.graph.nodes())
224
  if needed <= 0: return
 
225
  available = list(self.active_positions - set(self.graph.nodes()))
 
 
226
  if self.topology != "linear" and len(available) > needed:
 
227
  center = random.choice(list(self.graph.nodes()))
228
  available.sort(key=lambda n: (n[0]-center[0])**2 + (n[1]-center[1])**2)
 
 
 
229
  chosen = random.sample(available, needed)
230
  for n in chosen: self.graph.add_node(n)
231
  else:
 
388
  if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2): continue
389
  l1 = (e1[0][0]-e1[1][0])**2 + (e1[0][1]-e1[1][1])**2
390
  l2 = (e2[0][0]-e2[1][0])**2 + (e2[0][1]-e2[1][1])**2
 
391
  rem = e1 if l1 > l2 else e2
392
  self.graph.remove_edge(*rem)
393
 
394
  def _adjust_edges_to_target(self):
395
  current_edges = list(self.graph.edges())
396
  curr_count = len(current_edges)
 
397
  if curr_count > self.target_edges:
398
  to_remove = curr_count - self.target_edges
 
399
  sorted_edges = sorted(current_edges, key=lambda e: (e[0][0]-e[1][0])**2 + (e[0][1]-e[1][1])**2, reverse=True)
400
  for e in sorted_edges:
401
  if len(self.graph.edges()) <= self.target_edges: break
402
  self.graph.remove_edge(*e)
403
  if not nx.is_connected(self.graph): self.graph.add_edge(*e)
 
404
  elif curr_count < self.target_edges:
405
  needed = self.target_edges - curr_count
406
  nodes = list(self.graph.nodes())
407
  attempts = 0
 
 
 
 
408
  while len(self.graph.edges()) < self.target_edges and attempts < (needed * 30):
409
  attempts += 1
410
  u = random.choice(nodes)
 
411
  candidates = sorted(nodes, key=lambda n: (n[0]-u[0])**2 + (n[1]-u[1])**2)
 
412
  if len(candidates) < 2: continue
413
  v = random.choice(candidates[1:min(len(candidates), 10)])
 
414
  if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
415
  self.graph.add_edge(u, v)
416
 
 
437
 
438
  # === MANUAL EDITING ===
439
  def manual_add_node(self, x, y):
440
+ # FIX: Force Int Cast to avoid "Already Exists" due to float mismatch
441
+ x, y = int(x), int(y)
442
+
443
  if not (0 <= x <= self.width and 0 <= y <= self.height): return False, "Out of bounds."
444
  if self.graph.has_node((x, y)): return False, "Already exists."
445
  self.graph.add_node((x, y))
 
466
  # ==========================================
467
  # GRADIO HELPERS
468
  # ==========================================
469
+ def plot_graph(graph, width, height, title="Network", highlight_node=None):
470
  fig, ax = plt.subplots(figsize=(8, 8))
471
  pos = {node: (node[0], node[1]) for node in graph.nodes()}
472
+
473
+ # 1. Edges
474
  nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")
475
+
476
+ # 2. Nodes (Standard)
477
+ # Filter nodes that are NOT highlighted
478
+ normal_nodes = [n for n in graph.nodes() if n != highlight_node]
479
+ nx.draw_networkx_nodes(graph, pos, ax=ax, nodelist=normal_nodes, node_size=350, node_color="#4F46E5", edgecolors="white", linewidths=1.5)
480
+
481
+ # 3. Nodes (Highlight)
482
+ if highlight_node and graph.has_node(highlight_node):
483
+ nx.draw_networkx_nodes(graph, pos, ax=ax, nodelist=[highlight_node], node_size=400, node_color="#EF4444", edgecolors="white", linewidths=2.0)
484
+
485
  sorted_nodes = get_sorted_nodes(graph)
486
  labels = {node: str(i+1) for i, node in enumerate(sorted_nodes)}
487
  nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=8, font_color="white", font_weight="bold")
488
+
489
  ax.set_xlim(-1, width + 1)
490
  ax.set_ylim(-1, height + 1)
491
  ax.invert_yaxis()
 
504
  return gr.update(value=dims[0], interactive=False), gr.update(value=dims[1], interactive=False)
505
 
506
  def update_ui_for_variant(variant, width, height, topology, void_frac):
 
507
  is_custom = (variant == "Custom")
 
508
  if is_custom:
 
 
509
  temp_gen = NetworkGenerator(width, height, "F", topology, void_frac)
510
  def_nodes, def_edges = temp_gen.calculate_defaults()
 
511
  void_update = gr.update(interactive=True)
512
  target_node_update = gr.update(value=def_nodes, interactive=True)
513
  target_edge_update = gr.update(value=def_edges, interactive=True)
514
  else:
 
515
  area = width * height
516
  val = 0.60 if area <= 20 else 0.35
517
  void_update = gr.update(value=val, interactive=False)
518
  target_node_update = gr.update(value=0, interactive=False)
519
  target_edge_update = gr.update(value=0, interactive=False)
 
520
  return void_update, target_node_update, target_edge_update
521
 
522
  def save_single_visual_action(state_data):
 
548
  if not state_data or "graph" not in state_data: return None, "No graph.", state_data
549
  gen = NetworkGenerator(state_data["width"], state_data["height"])
550
  gen.graph = state_data["graph"]
551
+
552
+ # Store added node to pass to plotter
553
+ highlight = None
554
+
555
  if action == "Add Node":
556
+ # Ensure Int here too
557
+ x, y = int(x), int(y)
558
+ success, msg = gen.manual_add_node(x, y)
559
+ if success: highlight = (x, y)
560
  else:
561
  success, msg = gen.manual_delete_node_by_id(node_id)
562
+
563
  if success:
564
+ fig = plot_graph(gen.graph, state_data["width"], state_data["height"], "Edited", highlight_node=highlight)
565
  metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
566
  state_data["graph"] = gen.graph
567
  return fig, metrics, state_data
 
635
  preset.change(get_preset_dims, inputs_dims, [width, height])
636
  topology.change(get_preset_dims, inputs_dims, [width, height])
637
 
 
638
  inputs_var = [variant, width, height, topology, void_frac]
639
  variant.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
640
  width.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])