Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -140,12 +140,9 @@ class NetworkGenerator:
|
|
| 140 |
self.active_positions = None
|
| 141 |
|
| 142 |
def calculate_defaults(self):
|
| 143 |
-
"""Helper to return what the defaults WOULD be for this config."""
|
| 144 |
-
# Nodes
|
| 145 |
total_possible = (self.width + 1) * (self.height + 1)
|
| 146 |
scale = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}.get(self.topology, 1.0)
|
| 147 |
|
| 148 |
-
# Effective void fraction logic
|
| 149 |
if self.topology == "highly_connected": vf = max(0.0, self.node_drop_fraction * 0.8)
|
| 150 |
elif self.topology == "linear": vf = min(0.95, self.node_drop_fraction * 1.2)
|
| 151 |
else: vf = self.node_drop_fraction
|
|
@@ -153,11 +150,9 @@ class NetworkGenerator:
|
|
| 153 |
active_pct = 1.0 - vf
|
| 154 |
est_nodes = int(self.node_factor * scale * total_possible * active_pct)
|
| 155 |
|
| 156 |
-
# Edges
|
| 157 |
if self.topology == "highly_connected": est_edges = int(3.5 * est_nodes)
|
| 158 |
elif self.topology == "bottlenecks": est_edges = int(1.8 * est_nodes)
|
| 159 |
else: est_edges = int(1.5 * est_nodes)
|
| 160 |
-
|
| 161 |
return est_nodes, est_edges
|
| 162 |
|
| 163 |
def generate(self):
|
|
@@ -224,21 +219,13 @@ class NetworkGenerator:
|
|
| 224 |
self.graph.add_node(tuple(seed))
|
| 225 |
|
| 226 |
def _add_nodes(self):
|
| 227 |
-
# Improved: Even with target count, try to respect topology distribution
|
| 228 |
if self.target_nodes > 0:
|
| 229 |
needed = self.target_nodes - len(self.graph.nodes())
|
| 230 |
if needed <= 0: return
|
| 231 |
-
|
| 232 |
available = list(self.active_positions - set(self.graph.nodes()))
|
| 233 |
-
|
| 234 |
-
# If bottleneck/highly connected, prefer clustering vs random scatter
|
| 235 |
if self.topology != "linear" and len(available) > needed:
|
| 236 |
-
# Pick a random center and grow out
|
| 237 |
center = random.choice(list(self.graph.nodes()))
|
| 238 |
available.sort(key=lambda n: (n[0]-center[0])**2 + (n[1]-center[1])**2)
|
| 239 |
-
|
| 240 |
-
# Take closest needed? No, that makes one big blob.
|
| 241 |
-
# Take random subset of available to avoid lines?
|
| 242 |
chosen = random.sample(available, needed)
|
| 243 |
for n in chosen: self.graph.add_node(n)
|
| 244 |
else:
|
|
@@ -401,40 +388,29 @@ class NetworkGenerator:
|
|
| 401 |
if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2): continue
|
| 402 |
l1 = (e1[0][0]-e1[1][0])**2 + (e1[0][1]-e1[1][1])**2
|
| 403 |
l2 = (e2[0][0]-e2[1][0])**2 + (e2[0][1]-e2[1][1])**2
|
| 404 |
-
# FIX: remove_edge takes u,v. If rem is tuple, use *rem
|
| 405 |
rem = e1 if l1 > l2 else e2
|
| 406 |
self.graph.remove_edge(*rem)
|
| 407 |
|
| 408 |
def _adjust_edges_to_target(self):
|
| 409 |
current_edges = list(self.graph.edges())
|
| 410 |
curr_count = len(current_edges)
|
| 411 |
-
|
| 412 |
if curr_count > self.target_edges:
|
| 413 |
to_remove = curr_count - self.target_edges
|
| 414 |
-
# Remove edges that are longest first, but preserve "cluster" integrity if possible
|
| 415 |
sorted_edges = sorted(current_edges, key=lambda e: (e[0][0]-e[1][0])**2 + (e[0][1]-e[1][1])**2, reverse=True)
|
| 416 |
for e in sorted_edges:
|
| 417 |
if len(self.graph.edges()) <= self.target_edges: break
|
| 418 |
self.graph.remove_edge(*e)
|
| 419 |
if not nx.is_connected(self.graph): self.graph.add_edge(*e)
|
| 420 |
-
|
| 421 |
elif curr_count < self.target_edges:
|
| 422 |
needed = self.target_edges - curr_count
|
| 423 |
nodes = list(self.graph.nodes())
|
| 424 |
attempts = 0
|
| 425 |
-
|
| 426 |
-
# IMPROVED STRATEGY FOR ADDING EDGES TO TARGET
|
| 427 |
-
# Instead of random (u,v), preferentially pick u,v that are close
|
| 428 |
-
# or in same cluster to avoid "linear" lines across map
|
| 429 |
while len(self.graph.edges()) < self.target_edges and attempts < (needed * 30):
|
| 430 |
attempts += 1
|
| 431 |
u = random.choice(nodes)
|
| 432 |
-
# Pick v close to u
|
| 433 |
candidates = sorted(nodes, key=lambda n: (n[0]-u[0])**2 + (n[1]-u[1])**2)
|
| 434 |
-
# Skip self, pick from nearest 10
|
| 435 |
if len(candidates) < 2: continue
|
| 436 |
v = random.choice(candidates[1:min(len(candidates), 10)])
|
| 437 |
-
|
| 438 |
if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
|
| 439 |
self.graph.add_edge(u, v)
|
| 440 |
|
|
@@ -461,6 +437,9 @@ class NetworkGenerator:
|
|
| 461 |
|
| 462 |
# === MANUAL EDITING ===
|
| 463 |
def manual_add_node(self, x, y):
|
|
|
|
|
|
|
|
|
|
| 464 |
if not (0 <= x <= self.width and 0 <= y <= self.height): return False, "Out of bounds."
|
| 465 |
if self.graph.has_node((x, y)): return False, "Already exists."
|
| 466 |
self.graph.add_node((x, y))
|
|
@@ -487,14 +466,26 @@ class NetworkGenerator:
|
|
| 487 |
# ==========================================
|
| 488 |
# GRADIO HELPERS
|
| 489 |
# ==========================================
|
| 490 |
-
def plot_graph(graph, width, height, title="Network"):
|
| 491 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 492 |
pos = {node: (node[0], node[1]) for node in graph.nodes()}
|
|
|
|
|
|
|
| 493 |
nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
sorted_nodes = get_sorted_nodes(graph)
|
| 496 |
labels = {node: str(i+1) for i, node in enumerate(sorted_nodes)}
|
| 497 |
nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=8, font_color="white", font_weight="bold")
|
|
|
|
| 498 |
ax.set_xlim(-1, width + 1)
|
| 499 |
ax.set_ylim(-1, height + 1)
|
| 500 |
ax.invert_yaxis()
|
|
@@ -513,26 +504,19 @@ def get_preset_dims(preset_mode, topology):
|
|
| 513 |
return gr.update(value=dims[0], interactive=False), gr.update(value=dims[1], interactive=False)
|
| 514 |
|
| 515 |
def update_ui_for_variant(variant, width, height, topology, void_frac):
|
| 516 |
-
"""Handles unlocking UI and Pre-filling Defaults."""
|
| 517 |
is_custom = (variant == "Custom")
|
| 518 |
-
|
| 519 |
if is_custom:
|
| 520 |
-
# Calculate what the defaults WOULD be for this config
|
| 521 |
-
# so user starts editing from a reasonable number, not 0
|
| 522 |
temp_gen = NetworkGenerator(width, height, "F", topology, void_frac)
|
| 523 |
def_nodes, def_edges = temp_gen.calculate_defaults()
|
| 524 |
-
|
| 525 |
void_update = gr.update(interactive=True)
|
| 526 |
target_node_update = gr.update(value=def_nodes, interactive=True)
|
| 527 |
target_edge_update = gr.update(value=def_edges, interactive=True)
|
| 528 |
else:
|
| 529 |
-
# Fixed Mode
|
| 530 |
area = width * height
|
| 531 |
val = 0.60 if area <= 20 else 0.35
|
| 532 |
void_update = gr.update(value=val, interactive=False)
|
| 533 |
target_node_update = gr.update(value=0, interactive=False)
|
| 534 |
target_edge_update = gr.update(value=0, interactive=False)
|
| 535 |
-
|
| 536 |
return void_update, target_node_update, target_edge_update
|
| 537 |
|
| 538 |
def save_single_visual_action(state_data):
|
|
@@ -564,12 +548,20 @@ def manual_edit_action(action, x, y, node_id, state_data):
|
|
| 564 |
if not state_data or "graph" not in state_data: return None, "No graph.", state_data
|
| 565 |
gen = NetworkGenerator(state_data["width"], state_data["height"])
|
| 566 |
gen.graph = state_data["graph"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
if action == "Add Node":
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
| 569 |
else:
|
| 570 |
success, msg = gen.manual_delete_node_by_id(node_id)
|
|
|
|
| 571 |
if success:
|
| 572 |
-
fig = plot_graph(gen.graph, state_data["width"], state_data["height"], "Edited")
|
| 573 |
metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
|
| 574 |
state_data["graph"] = gen.graph
|
| 575 |
return fig, metrics, state_data
|
|
@@ -643,7 +635,6 @@ with gr.Blocks(title="Graph Generator Pro") as demo:
|
|
| 643 |
preset.change(get_preset_dims, inputs_dims, [width, height])
|
| 644 |
topology.change(get_preset_dims, inputs_dims, [width, height])
|
| 645 |
|
| 646 |
-
# Updated: Now includes topology for default calculation
|
| 647 |
inputs_var = [variant, width, height, topology, void_frac]
|
| 648 |
variant.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
|
| 649 |
width.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
|
|
|
|
| 140 |
self.active_positions = None
|
| 141 |
|
| 142 |
def calculate_defaults(self):
|
|
|
|
|
|
|
| 143 |
total_possible = (self.width + 1) * (self.height + 1)
|
| 144 |
scale = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}.get(self.topology, 1.0)
|
| 145 |
|
|
|
|
| 146 |
if self.topology == "highly_connected": vf = max(0.0, self.node_drop_fraction * 0.8)
|
| 147 |
elif self.topology == "linear": vf = min(0.95, self.node_drop_fraction * 1.2)
|
| 148 |
else: vf = self.node_drop_fraction
|
|
|
|
| 150 |
active_pct = 1.0 - vf
|
| 151 |
est_nodes = int(self.node_factor * scale * total_possible * active_pct)
|
| 152 |
|
|
|
|
| 153 |
if self.topology == "highly_connected": est_edges = int(3.5 * est_nodes)
|
| 154 |
elif self.topology == "bottlenecks": est_edges = int(1.8 * est_nodes)
|
| 155 |
else: est_edges = int(1.5 * est_nodes)
|
|
|
|
| 156 |
return est_nodes, est_edges
|
| 157 |
|
| 158 |
def generate(self):
|
|
|
|
| 219 |
self.graph.add_node(tuple(seed))
|
| 220 |
|
| 221 |
def _add_nodes(self):
|
|
|
|
| 222 |
if self.target_nodes > 0:
|
| 223 |
needed = self.target_nodes - len(self.graph.nodes())
|
| 224 |
if needed <= 0: return
|
|
|
|
| 225 |
available = list(self.active_positions - set(self.graph.nodes()))
|
|
|
|
|
|
|
| 226 |
if self.topology != "linear" and len(available) > needed:
|
|
|
|
| 227 |
center = random.choice(list(self.graph.nodes()))
|
| 228 |
available.sort(key=lambda n: (n[0]-center[0])**2 + (n[1]-center[1])**2)
|
|
|
|
|
|
|
|
|
|
| 229 |
chosen = random.sample(available, needed)
|
| 230 |
for n in chosen: self.graph.add_node(n)
|
| 231 |
else:
|
|
|
|
| 388 |
if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2): continue
|
| 389 |
l1 = (e1[0][0]-e1[1][0])**2 + (e1[0][1]-e1[1][1])**2
|
| 390 |
l2 = (e2[0][0]-e2[1][0])**2 + (e2[0][1]-e2[1][1])**2
|
|
|
|
| 391 |
rem = e1 if l1 > l2 else e2
|
| 392 |
self.graph.remove_edge(*rem)
|
| 393 |
|
| 394 |
def _adjust_edges_to_target(self):
|
| 395 |
current_edges = list(self.graph.edges())
|
| 396 |
curr_count = len(current_edges)
|
|
|
|
| 397 |
if curr_count > self.target_edges:
|
| 398 |
to_remove = curr_count - self.target_edges
|
|
|
|
| 399 |
sorted_edges = sorted(current_edges, key=lambda e: (e[0][0]-e[1][0])**2 + (e[0][1]-e[1][1])**2, reverse=True)
|
| 400 |
for e in sorted_edges:
|
| 401 |
if len(self.graph.edges()) <= self.target_edges: break
|
| 402 |
self.graph.remove_edge(*e)
|
| 403 |
if not nx.is_connected(self.graph): self.graph.add_edge(*e)
|
|
|
|
| 404 |
elif curr_count < self.target_edges:
|
| 405 |
needed = self.target_edges - curr_count
|
| 406 |
nodes = list(self.graph.nodes())
|
| 407 |
attempts = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
while len(self.graph.edges()) < self.target_edges and attempts < (needed * 30):
|
| 409 |
attempts += 1
|
| 410 |
u = random.choice(nodes)
|
|
|
|
| 411 |
candidates = sorted(nodes, key=lambda n: (n[0]-u[0])**2 + (n[1]-u[1])**2)
|
|
|
|
| 412 |
if len(candidates) < 2: continue
|
| 413 |
v = random.choice(candidates[1:min(len(candidates), 10)])
|
|
|
|
| 414 |
if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
|
| 415 |
self.graph.add_edge(u, v)
|
| 416 |
|
|
|
|
| 437 |
|
| 438 |
# === MANUAL EDITING ===
|
| 439 |
def manual_add_node(self, x, y):
|
| 440 |
+
# FIX: Force Int Cast to avoid "Already Exists" due to float mismatch
|
| 441 |
+
x, y = int(x), int(y)
|
| 442 |
+
|
| 443 |
if not (0 <= x <= self.width and 0 <= y <= self.height): return False, "Out of bounds."
|
| 444 |
if self.graph.has_node((x, y)): return False, "Already exists."
|
| 445 |
self.graph.add_node((x, y))
|
|
|
|
| 466 |
# ==========================================
|
| 467 |
# GRADIO HELPERS
|
| 468 |
# ==========================================
|
| 469 |
+
def plot_graph(graph, width, height, title="Network", highlight_node=None):
|
| 470 |
fig, ax = plt.subplots(figsize=(8, 8))
|
| 471 |
pos = {node: (node[0], node[1]) for node in graph.nodes()}
|
| 472 |
+
|
| 473 |
+
# 1. Edges
|
| 474 |
nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")
|
| 475 |
+
|
| 476 |
+
# 2. Nodes (Standard)
|
| 477 |
+
# Filter nodes that are NOT highlighted
|
| 478 |
+
normal_nodes = [n for n in graph.nodes() if n != highlight_node]
|
| 479 |
+
nx.draw_networkx_nodes(graph, pos, ax=ax, nodelist=normal_nodes, node_size=350, node_color="#4F46E5", edgecolors="white", linewidths=1.5)
|
| 480 |
+
|
| 481 |
+
# 3. Nodes (Highlight)
|
| 482 |
+
if highlight_node and graph.has_node(highlight_node):
|
| 483 |
+
nx.draw_networkx_nodes(graph, pos, ax=ax, nodelist=[highlight_node], node_size=400, node_color="#EF4444", edgecolors="white", linewidths=2.0)
|
| 484 |
+
|
| 485 |
sorted_nodes = get_sorted_nodes(graph)
|
| 486 |
labels = {node: str(i+1) for i, node in enumerate(sorted_nodes)}
|
| 487 |
nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=8, font_color="white", font_weight="bold")
|
| 488 |
+
|
| 489 |
ax.set_xlim(-1, width + 1)
|
| 490 |
ax.set_ylim(-1, height + 1)
|
| 491 |
ax.invert_yaxis()
|
|
|
|
| 504 |
return gr.update(value=dims[0], interactive=False), gr.update(value=dims[1], interactive=False)
|
| 505 |
|
| 506 |
def update_ui_for_variant(variant, width, height, topology, void_frac):
|
|
|
|
| 507 |
is_custom = (variant == "Custom")
|
|
|
|
| 508 |
if is_custom:
|
|
|
|
|
|
|
| 509 |
temp_gen = NetworkGenerator(width, height, "F", topology, void_frac)
|
| 510 |
def_nodes, def_edges = temp_gen.calculate_defaults()
|
|
|
|
| 511 |
void_update = gr.update(interactive=True)
|
| 512 |
target_node_update = gr.update(value=def_nodes, interactive=True)
|
| 513 |
target_edge_update = gr.update(value=def_edges, interactive=True)
|
| 514 |
else:
|
|
|
|
| 515 |
area = width * height
|
| 516 |
val = 0.60 if area <= 20 else 0.35
|
| 517 |
void_update = gr.update(value=val, interactive=False)
|
| 518 |
target_node_update = gr.update(value=0, interactive=False)
|
| 519 |
target_edge_update = gr.update(value=0, interactive=False)
|
|
|
|
| 520 |
return void_update, target_node_update, target_edge_update
|
| 521 |
|
| 522 |
def save_single_visual_action(state_data):
|
|
|
|
| 548 |
if not state_data or "graph" not in state_data: return None, "No graph.", state_data
|
| 549 |
gen = NetworkGenerator(state_data["width"], state_data["height"])
|
| 550 |
gen.graph = state_data["graph"]
|
| 551 |
+
|
| 552 |
+
# Store added node to pass to plotter
|
| 553 |
+
highlight = None
|
| 554 |
+
|
| 555 |
if action == "Add Node":
|
| 556 |
+
# Ensure Int here too
|
| 557 |
+
x, y = int(x), int(y)
|
| 558 |
+
success, msg = gen.manual_add_node(x, y)
|
| 559 |
+
if success: highlight = (x, y)
|
| 560 |
else:
|
| 561 |
success, msg = gen.manual_delete_node_by_id(node_id)
|
| 562 |
+
|
| 563 |
if success:
|
| 564 |
+
fig = plot_graph(gen.graph, state_data["width"], state_data["height"], "Edited", highlight_node=highlight)
|
| 565 |
metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
|
| 566 |
state_data["graph"] = gen.graph
|
| 567 |
return fig, metrics, state_data
|
|
|
|
| 635 |
preset.change(get_preset_dims, inputs_dims, [width, height])
|
| 636 |
topology.change(get_preset_dims, inputs_dims, [width, height])
|
| 637 |
|
|
|
|
| 638 |
inputs_var = [variant, width, height, topology, void_frac]
|
| 639 |
variant.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
|
| 640 |
width.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
|