# TahaRasouli's picture
# Rename app.py to app3.py
# c0a606c verified
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
import random
import time
import json
import os
import shutil
from datetime import datetime
# ==========================================
# 1. JSON & HELPER LOGIC
# ==========================================
def get_sorted_nodes(G):
    """Return the nodes of ``G`` ordered by x coordinate, then y coordinate.

    This ordering defines the 1-based node IDs used both in the JSON export
    and in the plot labels.
    """
    def _xy(node):
        return node[0], node[1]

    return sorted(G.nodes(), key=_xy)
def prepare_edges_for_json(G):
    """Number the nodes of ``G`` and express its edges in terms of those IDs.

    Nodes receive string IDs "1", "2", ... in get_sorted_nodes order.

    Returns:
        ``(edges_formatted, I, nodes_list_dict)`` where ``edges_formatted`` is a
        list of ``{"room1": id, "room2": id}`` dicts, ``I`` is the list of ID
        strings and ``nodes_list_dict`` maps each ID string to its node.
    """
    nodes_list_dict = {
        str(position): node
        for position, node in enumerate(get_sorted_nodes(G), start=1)
    }
    I = list(nodes_list_dict)
    node_to_id = {node: node_id for node_id, node in nodes_list_dict.items()}
    edges_formatted = [
        {"room1": node_to_id[a], "room2": node_to_id[b]}
        for a, b in G.edges()
        if a in node_to_id and b in node_to_id
    ]
    return edges_formatted, I, nodes_list_dict
def prepare_parameter_for_json(G, I, nodes_list_dict):
    """Draw the randomized instance parameters written into the JSON export.

    Args:
        G: graph whose node count sizes the per-node cost tables.
        I: node ID strings ("1".."n"), as produced by prepare_edges_for_json.
        nodes_list_dict: mapping from node ID string to the node's (x, y).

    Returns:
        Tuple ``(dismantled, assignment, delivered, conditioningCapacity,
        conditioningDuration, CostMT, CostMB, CostRT, CostRB, Coord)``; each
        item is a list of dicts.  There are always 5 "m" entries and 3 "r"
        entries, matching the M and R sets in generate_full_json_dict.

    Raises:
        ValueError: if the graph has no nodes.
    """
    n_count = len(G.nodes())
    if n_count == 0 or not I:
        raise ValueError("Cannot build instance parameters for an empty graph.")

    # Node weights decay with the ID, so low IDs are drawn more often.
    node_weights = [1 / (1 + (i + 1) * 2 / 30) for i in range(n_count)]
    m_nodes = random.choices(I, weights=node_weights, k=5)
    # "t" weights decay faster, strongly favouring small values of 1..10.
    t_weights = [1 / (1 + (i + 1) * 2 / 5) for i in range(10)]
    m_times = random.choices(range(1, 11), weights=t_weights, k=5)

    dismantled = []
    conditioningDuration = []
    assignment = []
    used_r = []
    for m in range(5):
        m_id = str(m + 1)
        dismantled.append({"m": m_id, "i": str(m_nodes[m]), "t": m_times[m], "value": 1})
        conditioningDuration.append({"m": m_id, "value": 1})
        r = random.randint(1, 3)
        if m > 2:
            # After three random draws at most two "r" values are unused; the
            # last two entries take the highest unused value so that all of
            # 1, 2 and 3 appear at least once.
            unused = [c for c in (1, 2, 3) if c not in used_r]
            if unused:
                r = unused[-1]
        used_r.append(r)
        assignment.append({"m": m_id, "r": str(r), "value": 1})

    delivered_times = random.choices(range(1, 11), weights=t_weights, k=3)
    delivered = []
    conditioningCapacity = []
    for r in range(3):
        r_id = str(r + 1)
        delivered.append({"r": r_id, "i": "1", "t": delivered_times[r], "value": 1})
        conditioningCapacity.append({"r": r_id, "value": 1})

    CostMT, CostMB, CostRT, CostRB, Coord = [], [], [], [], []
    for i in range(n_count):
        s_id = str(i + 1)
        CostMT.append({"i": s_id, "value": random.choice([2, 5])})
        CostMB.append({"i": s_id, "value": random.choice([5, 10, 30])})
        CostRT.append({"i": s_id, "value": random.choice([4, 10])})
        # Node "1" gets a fixed CostRB of 1000 (no random draw for it).
        CostRB.append({"i": s_id, "value": 1000 if i == 0 else random.choice([20, 30, 100])})
        if s_id in nodes_list_dict:
            Coord.append({"i": s_id, "Coordinates": nodes_list_dict[s_id]})

    return (dismantled, assignment, delivered, conditioningCapacity, conditioningDuration,
            CostMT, CostMB, CostRT, CostRB, Coord)
def generate_full_json_dict(G, loop=0):
    """Assemble the complete JSON-serialisable instance description for ``G``.

    Args:
        G: the network graph to export.
        loop: not used by this function; kept so existing callers still work.

    Returns:
        Dict with keys "description", "sets" (I, E, M, R) and "params".
    """
    edges, node_ids, id_to_node = prepare_edges_for_json(G)
    (dismantled, assignment, delivered, cond_cap, cond_dur,
     cost_mt, cost_mb, cost_rt, cost_rb, coord) = prepare_parameter_for_json(G, node_ids, id_to_node)

    return {
        "description": "Generated by Gradio",
        "sets": {
            "I": node_ids,
            "E": {"bidirectional": True, "seed": 1, "edges": edges},
            "M": ["1", "2", "3", "4", "5"],
            "R": ["1", "2", "3"],
        },
        "params": {
            "defaults": {"V": 1000, "CostMB": 100, "CostMT": 20, "CostRB": 300, "CostRT": 50},
            "t_max": 100,
            "V": [{"m": "1", "i": "1", "value": 42}],
            "dismantled": dismantled,
            "delivered": delivered,
            "conditioningCapacity": cond_cap,
            "conditioningDuration": cond_dur,
            "assignment": assignment,
            "CostMT": cost_mt,
            "CostMB": cost_mb,
            "CostRT": cost_rt,
            "CostRB": cost_rb,
            "CostZR": 9,
            "CostZH": 5,
            "Coord": coord,
        },
    }
# ==========================================
# 2. NETWORK GENERATOR CLASS
# ==========================================
class NetworkGenerator:
    """Random spatial network generator on an integer grid.

    Nodes are integer ``(x, y)`` tuples with ``0 <= x <= width`` and
    ``0 <= y <= height``; edges are straight segments between nodes.  The
    generator avoids edge crossings where it can and retries until the
    resulting ``networkx.Graph`` is connected.

    Args:
        width, height: grid extent (cast to int).
        variant: "F" uses the fixed ``node_factor``; any other value draws a
            random factor in [0.3, 0.6] (only when ``target_nodes`` is 0).
        topology: "highly_connected", "bottlenecks" or "linear"; any other
            value only gets the nearest-neighbour growth edges.
        node_drop_fraction: share of grid points made unavailable, scaled per
            topology; ignored when ``target_nodes`` is set.
        target_nodes, target_edges: node / edge targets when > 0 (the edge
            target is best-effort).
        bottleneck_cluster_count: number of clusters for "bottlenecks";
            defaults to ``max(2, int(width * height / 18))``.
        bottleneck_edges_per_link: edges added between each pair of clusters
            that are adjacent in center order.
    """

    # Multiplier applied to ``node_factor`` when sizing the node set.
    _NODE_SCALE = {"highly_connected": 1.2, "bottlenecks": 0.85, "linear": 0.75}

    def __init__(self, width=10, height=10, variant="F", topology="highly_connected",
                 node_drop_fraction=0.1, target_nodes=0, target_edges=0,
                 bottleneck_cluster_count=None, bottleneck_edges_per_link=1):
        self.variant = variant.upper()
        self.topology = topology.lower()
        self.width = int(width)
        self.height = int(height)
        self.node_drop_fraction = float(node_drop_fraction)
        self.target_nodes = int(target_nodes)
        self.target_edges = int(target_edges)
        self.node_factor = 0.4
        if bottleneck_cluster_count is None:
            area = self.width * self.height
            self.bottleneck_cluster_count = max(2, int(area / 18))
        else:
            self.bottleneck_cluster_count = int(bottleneck_cluster_count)
        self.bottleneck_edges_per_link = int(bottleneck_edges_per_link)
        self.graph = None
        self.active_positions = None

    def calculate_defaults(self):
        """Estimate ``(node_count, edge_count)`` for the current settings.

        Used by the UI to pre-fill custom node/edge targets.  Ignores
        ``target_nodes``/``target_edges`` and the random ``variant`` factor.
        """
        total_possible = (self.width + 1) * (self.height + 1)
        scale = self._NODE_SCALE.get(self.topology, 1.0)
        active_pct = 1.0 - self._topology_drop_fraction()
        est_nodes = int(self.node_factor * scale * total_possible * active_pct)
        if self.topology == "highly_connected":
            est_edges = int(3.5 * est_nodes)
        elif self.topology == "bottlenecks":
            est_edges = int(1.8 * est_nodes)
        else:
            est_edges = int(1.5 * est_nodes)
        return est_nodes, est_edges

    def generate(self, max_attempts=15):
        """Build a connected network, retrying up to ``max_attempts`` times.

        Returns:
            The generated ``networkx.Graph`` (also stored on ``self.graph``).

        Raises:
            RuntimeError: if no attempt yields a connected graph.
        """
        for _ in range(max_attempts):
            self._build_node_mask()
            self._initialize_graph()
            self._add_nodes()
            nodes = list(self.graph.nodes())
            if len(nodes) < 2:
                continue
            if self.topology == "bottlenecks":
                self._build_bottleneck_clusters(nodes)
            else:
                # Spanning growth first, then topology-specific extra edges.
                self._connect_all_nodes_by_nearby_growth(nodes)
                self._add_edges()
            self._remove_intersections()
            if self.target_edges > 0:
                self._adjust_edges_to_target()
            else:
                self._enforce_edge_budget()
            if not nx.is_connected(self.graph):
                self._force_connect_components()
            self._remove_intersections()
            if nx.is_connected(self.graph):
                return self.graph
        raise RuntimeError("Failed to generate valid network.")

    def _topology_drop_fraction(self):
        """Return ``node_drop_fraction`` scaled for the topology.

        x0.8 for "highly_connected", x1.2 (capped at 0.95) for "linear",
        unchanged otherwise.
        """
        base = self.node_drop_fraction
        if self.topology == "highly_connected":
            return max(0.0, base * 0.8)
        if self.topology == "linear":
            return min(0.95, base * 1.2)
        return base

    def _effective_node_drop_fraction(self):
        """Drop fraction used for the node mask; 0.0 when ``target_nodes`` is set."""
        if self.target_nodes > 0:
            return 0.0
        return self._topology_drop_fraction()

    def _build_node_mask(self):
        """Set ``self.active_positions``: the grid points allowed to host nodes."""
        all_positions = [(x, y) for x in range(self.width + 1) for y in range(self.height + 1)]
        if self.target_nodes > 0:
            self.active_positions = set(all_positions)
            return
        drop = int(self._effective_node_drop_fraction() * len(all_positions))
        deactivated = set(random.sample(all_positions, drop)) if drop > 0 else set()
        self.active_positions = set(all_positions) - deactivated

    def _initialize_graph(self):
        """Start a fresh graph seeded with one node, preferably near the centre.

        The seed comes from active positions inside the central region (a
        quarter-extent margin, at least 1, on each side); failing that from any
        active position; with no active positions the graph stays empty.
        """
        self.graph = nx.Graph()
        margin_x = max(1, self.width // 4)
        margin_y = max(1, self.height // 4)
        middle_active = [
            p for p in self.active_positions
            if margin_x <= p[0] <= self.width - margin_x
            and margin_y <= p[1] <= self.height - margin_y
        ]
        if middle_active:
            seed = random.choice(middle_active)
        elif self.active_positions:
            seed = random.choice(list(self.active_positions))
        else:
            return
        self.graph.add_node(tuple(seed))

    def _add_nodes(self):
        """Populate ``self.graph`` with nodes taken from the active positions.

        With ``target_nodes > 0`` the missing nodes are sampled uniformly from
        the free positions (all of them if there are not enough).  Otherwise a
        node count derived from the node factor and topology is reached by
        rejection sampling, giving up after ``20 * target`` draws.
        """
        if self.target_nodes > 0:
            needed = self.target_nodes - self.graph.number_of_nodes()
            if needed <= 0:
                return
            available = list(self.active_positions - set(self.graph.nodes()))
            if len(available) > needed:
                available = random.sample(available, needed)
            self.graph.add_nodes_from(available)
            return

        total_possible = (self.width + 1) * (self.height + 1)
        base = self.node_factor if self.variant == "F" else random.uniform(0.3, 0.6)
        scale = self._NODE_SCALE.get(self.topology, 1.0)
        target = min(int(base * scale * total_possible), len(self.active_positions))
        attempts = 0
        while len(self.graph.nodes()) < target and attempts < (target * 20):
            attempts += 1
            pos = (random.randint(0, self.width), random.randint(0, self.height))
            if pos in self.active_positions and pos not in self.graph:
                self.graph.add_node(pos)

    def _connect_all_nodes_by_nearby_growth(self, nodes):
        """Attach every node in ``nodes`` to a growing connected set, one edge each.

        Each step picks a random remaining node within Manhattan distance 4 of
        the connected set (or the closest one if none is that near).  It links
        that node to the first of its three nearest connected nodes whose edge
        does not cross an existing edge.  If all three would cross, it links to
        the nearest one anyway.
        """
        if not nodes:
            return

        def manhattan(a, b):
            return abs(a[0] - b[0]) + abs(a[1] - b[1])

        current = random.choice(nodes)
        connected = {current}
        remaining = set(nodes) - connected
        while remaining:
            candidates = [n for n in remaining if min(manhattan(n, c) for c in connected) <= 4]
            if not candidates:
                candidates = [min(remaining, key=lambda r: min(manhattan(r, c) for c in connected))]
            candidate = random.choice(candidates)
            neighbors = sorted(connected, key=lambda c: manhattan(c, candidate))
            for n in neighbors[:3]:
                if not self._would_create_intersection(n, candidate):
                    self.graph.add_edge(n, candidate)
                    break
            else:
                # Accept a crossing rather than leave the node disconnected.
                self.graph.add_edge(neighbors[0], candidate)
            connected.add(candidate)
            remaining.remove(candidate)

    def _compute_edge_count(self):
        """Edge budget: ``target_edges`` if set, else a per-topology multiple of
        the node count (3.5x, 1.8x, or a random 1.2x-2.0x)."""
        if self.target_edges > 0:
            return self.target_edges
        n = len(self.graph.nodes())
        if self.topology == "highly_connected":
            return int(3.5 * n)
        if self.topology == "bottlenecks":
            return int(1.8 * n)
        return int(random.uniform(1.2, 2.0) * n)

    def _add_edges(self):
        """Add topology-specific extra edges (dense for highly_connected, a
        chain for linear; nothing for other topologies)."""
        nodes = list(self.graph.nodes())
        if self.topology == "highly_connected":
            self._add_cluster_dense(nodes, self._compute_edge_count())
        elif self.topology == "linear":
            self._make_linear(nodes)

    def _make_linear(self, nodes):
        """Chain ``nodes`` in (x, y) order, skipping links that would cross an edge."""
        nodes_sorted = sorted(nodes, key=lambda p: (p[0], p[1]))
        for prev, nxt in zip(nodes_sorted, nodes_sorted[1:]):
            if not self._would_create_intersection(prev, nxt):
                self.graph.add_edge(prev, nxt)

    def _add_cluster_dense(self, nodes, max_edges):
        """Add non-crossing edges between nearby pairs of ``nodes``.

        Pairs within Chebyshev distance 4 (10 when ``target_edges`` is set) are
        tried in random order.  Without an edge target, stops after
        ``max_edges`` new edges.  With a target, trimming is left to
        ``_adjust_edges_to_target``.
        """
        nodes = list(nodes)
        random.shuffle(nodes)
        dist_limit = 10 if self.target_edges > 0 else 4
        edges_added = 0
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if self.target_edges == 0 and edges_added >= max_edges:
                    return
                if max(abs(n1[0] - n2[0]), abs(n1[1] - n2[1])) > dist_limit:
                    continue
                # Only genuinely new edges count toward the cap.
                if self.graph.has_edge(n1, n2) or self._would_create_intersection(n1, n2):
                    continue
                self.graph.add_edge(n1, n2)
                edges_added += 1

    def _build_bottleneck_clusters(self, nodes):
        """Rewire the graph as dense spatial clusters joined by sparse links."""
        self.graph.remove_edges_from(list(self.graph.edges()))
        clusters, centers = self._spatial_cluster_nodes(nodes, k=self.bottleneck_cluster_count)
        for cluster in clusters:
            if len(cluster) < 2:
                continue
            self._connect_cluster_by_nearby_growth(cluster)
            self._add_cluster_dense(list(cluster), max_edges=max(1, int(3.5 * len(cluster))))
        # Order clusters by their centers' (x, y) and link consecutive pairs.
        order = sorted(range(len(clusters)), key=lambda i: (centers[i][0], centers[i][1]))
        for a_idx, b_idx in zip(order[:-1], order[1:]):
            self._add_bottleneck_links(clusters[a_idx], clusters[b_idx], self.bottleneck_edges_per_link)
        if not nx.is_connected(self.graph):
            self._force_connect_components()

    def _force_connect_components(self):
        """Join components with the shortest non-crossing edge found.

        Each round considers only the first two components and at most 30
        random nodes from each.  Stops early if no non-crossing edge exists
        among those samples or no progress is made.
        """
        components = list(nx.connected_components(self.graph))
        while len(components) > 1:
            c1, c2 = list(components[0]), list(components[1])
            best_pair, min_dist = None, float("inf")
            s1 = c1 if len(c1) < 30 else random.sample(c1, 30)
            s2 = c2 if len(c2) < 30 else random.sample(c2, 30)
            for u in s1:
                for v in s2:
                    d = (u[0] - v[0]) ** 2 + (u[1] - v[1]) ** 2
                    if d < min_dist and not self._would_create_intersection(u, v):
                        min_dist, best_pair = d, (u, v)
            if not best_pair:
                break
            self.graph.add_edge(best_pair[0], best_pair[1])
            prev_len = len(components)
            components = list(nx.connected_components(self.graph))
            if len(components) == prev_len:
                break

    def _spatial_cluster_nodes(self, nodes, k):
        """Partition ``nodes`` around ``k`` random centers (Chebyshev nearest).

        Returns ``(clusters, centers)``.  If ``k >= len(nodes)``, every node is
        its own cluster and its own center.
        """
        nodes = list(nodes)
        if k >= len(nodes):
            return [[n] for n in nodes], nodes[:]
        centers = random.sample(nodes, k)
        clusters = [[] for _ in range(k)]
        for n in nodes:
            best_i = min(range(k), key=lambda i: max(abs(n[0] - centers[i][0]), abs(n[1] - centers[i][1])))
            clusters[best_i].append(n)
        return clusters, centers

    def _connect_cluster_by_nearby_growth(self, cluster_nodes):
        """Connect one cluster's nodes with the nearby-growth procedure."""
        self._connect_all_nodes_by_nearby_growth(cluster_nodes)

    def _add_bottleneck_links(self, cluster_a, cluster_b, m):
        """Add up to ``m`` new non-crossing edges between two clusters,
        trying the shortest (Chebyshev) pairs first."""
        pairs = sorted(
            ((max(abs(u[0] - v[0]), abs(u[1] - v[1])), u, v) for u in cluster_a for v in cluster_b),
            key=lambda t: t[0],
        )
        added = 0
        for _, u, v in pairs:
            if added >= m:
                break
            if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
                self.graph.add_edge(u, v)
                added += 1

    def _remove_intersections(self):
        """Remove crossing edges, in up to 5 passes.

        For each crossing pair the longer edge is dropped.  Graphs with more
        than 600 edges are only spot-checked on a random 400-edge sample per
        pass.  This may disconnect the graph; callers re-check connectivity.
        """
        for _ in range(5):
            edges = list(self.graph.edges())
            check_edges = random.sample(edges, 400) if len(edges) > 600 else edges
            intersections = []
            for i, e1 in enumerate(check_edges):
                for e2 in check_edges[i + 1:]:
                    if self._segments_intersect(e1[0], e1[1], e2[0], e2[1]):
                        intersections.append((e1, e2))
            if not intersections:
                break
            for e1, e2 in intersections:
                if not self.graph.has_edge(*e1) or not self.graph.has_edge(*e2):
                    continue
                l1 = (e1[0][0] - e1[1][0]) ** 2 + (e1[0][1] - e1[1][1]) ** 2
                l2 = (e2[0][0] - e2[1][0]) ** 2 + (e2[0][1] - e2[1][1]) ** 2
                self.graph.remove_edge(*(e1 if l1 > l2 else e2))

    def _adjust_edges_to_target(self):
        """Move the edge count toward ``target_edges`` (best effort).

        Too many edges: drop the longest first, restoring any whose removal
        disconnects the graph.  Too few: add random non-crossing edges from a
        node to one of its 9 nearest nodes, for at most ``30 * deficit`` tries.
        """
        current_edges = list(self.graph.edges())
        curr_count = len(current_edges)
        if curr_count > self.target_edges:
            sorted_edges = sorted(
                current_edges,
                key=lambda e: (e[0][0] - e[1][0]) ** 2 + (e[0][1] - e[1][1]) ** 2,
                reverse=True,
            )
            for e in sorted_edges:
                if len(self.graph.edges()) <= self.target_edges:
                    break
                self.graph.remove_edge(*e)
                if not nx.is_connected(self.graph):
                    self.graph.add_edge(*e)
        elif curr_count < self.target_edges:
            needed = self.target_edges - curr_count
            nodes = list(self.graph.nodes())
            attempts = 0
            while len(self.graph.edges()) < self.target_edges and attempts < (needed * 30):
                attempts += 1
                u = random.choice(nodes)
                # candidates[0] is u itself (distance 0).
                candidates = sorted(nodes, key=lambda n: (n[0] - u[0]) ** 2 + (n[1] - u[1]) ** 2)
                if len(candidates) < 2:
                    continue
                v = random.choice(candidates[1:min(len(candidates), 10)])
                if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
                    self.graph.add_edge(u, v)

    def _enforce_edge_budget(self):
        """Remove random edges until the edge budget from _compute_edge_count is met.

        An edge is removed only if its endpoints remain connected afterwards,
        so no component is ever split.  Stops when the budget is reached or
        every remaining edge is a bridge.
        """
        budget = self._compute_edge_count()
        if self.graph.number_of_edges() <= budget:
            return
        candidates = list(self.graph.edges())
        random.shuffle(candidates)
        for u, v in candidates:
            if self.graph.number_of_edges() <= budget:
                break
            self.graph.remove_edge(u, v)
            if not nx.has_path(self.graph, u, v):
                # (u, v) was a bridge: keep it and try the next edge.
                self.graph.add_edge(u, v)

    def _segments_intersect(self, a, b, c, d):
        """Return True if segment a-b properly crosses segment c-d.

        Strict orientation test.  Segments that share an endpoint, touch at an
        endpoint, or overlap collinearly are reported as NOT intersecting.
        """
        if a == c or a == d or b == c or b == d:
            return False

        def ccw(p, q, r):
            return (r[1] - p[1]) * (q[0] - p[0]) > (q[1] - p[1]) * (r[0] - p[0])

        return ccw(a, c, d) != ccw(b, c, d) and ccw(a, b, c) != ccw(a, b, d)

    def _would_create_intersection(self, u, v):
        """Return True if a new edge u-v would cross an existing edge that
        does not share an endpoint with it."""
        for a, b in self.graph.edges():
            if u == a or u == b or v == a or v == b:
                continue
            if self._segments_intersect(u, v, a, b):
                return True
        return False

    # === MANUAL EDITING ===
    def manual_add_node(self, x, y):
        """Add node (x, y) and link it to its nearest existing node.

        Coordinates are cast to int so float UI values match integer nodes.
        The link is skipped if it would cross an existing edge, which leaves
        the new node isolated.

        Returns:
            ``(success, message)`` tuple.
        """
        x, y = int(x), int(y)
        if not (0 <= x <= self.width and 0 <= y <= self.height):
            return False, "Out of bounds."
        new = (x, y)
        if self.graph.has_node(new):
            return False, "Already exists."
        self.graph.add_node(new)
        others = [n for n in self.graph.nodes() if n != new]
        if others:
            closest = min(others, key=lambda n: (n[0] - x) ** 2 + (n[1] - y) ** 2)
            if not self._would_create_intersection(new, closest):
                self.graph.add_edge(new, closest)
        return True, "Added."

    def manual_delete_node_by_id(self, node_id):
        """Delete the node whose 1-based ID (get_sorted_nodes order) is ``node_id``.

        Tries to reconnect the remaining components if the deletion split the
        graph.

        Returns:
            ``(success, message)`` tuple.  A missing or non-numeric ID (e.g.
            ``None`` from a cleared input) yields ``(False, "Invalid ID.")``.
        """
        try:
            idx = int(node_id) - 1
        except (TypeError, ValueError, OverflowError):
            return False, "Invalid ID."
        sorted_nodes = get_sorted_nodes(self.graph)
        if not 0 <= idx < len(sorted_nodes):
            return False, f"ID {node_id} not found."
        node_to_del = sorted_nodes[idx]
        self.graph.remove_node(node_to_del)
        if len(self.graph.nodes()) > 1 and not nx.is_connected(self.graph):
            self._force_connect_components()
        return True, f"Node {node_id} {node_to_del} removed."
# ==========================================
# GRADIO HELPERS
# ==========================================
def plot_graph(graph, width, height, title="Network", highlight_node=None):
    """Draw ``graph`` on a new 8x8-inch matplotlib figure and return the figure.

    Each node is placed at its own (x, y) and labelled with its 1-based ID in
    get_sorted_nodes order.  The y axis is inverted, so (0, 0) is top-left.
    ``highlight_node``, when truthy and present in the graph, is drawn larger
    and in red.  The caller owns (and should eventually close) the figure.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    positions = {n: (n[0], n[1]) for n in graph.nodes()}

    nx.draw_networkx_edges(graph, positions, ax=ax, width=2, alpha=0.6, edge_color="#333")

    regular_nodes = [n for n in graph.nodes() if n != highlight_node]
    nx.draw_networkx_nodes(
        graph, positions, ax=ax, nodelist=regular_nodes,
        node_size=350, node_color="#4F46E5", edgecolors="white", linewidths=1.5,
    )
    if highlight_node and graph.has_node(highlight_node):
        nx.draw_networkx_nodes(
            graph, positions, ax=ax, nodelist=[highlight_node],
            node_size=400, node_color="#EF4444", edgecolors="white", linewidths=2.0,
        )

    id_labels = {node: str(rank) for rank, node in enumerate(get_sorted_nodes(graph), start=1)}
    nx.draw_networkx_labels(
        graph, positions, id_labels, ax=ax,
        font_size=8, font_color="white", font_weight="bold",
    )

    ax.set_xlim(-1, width + 1)
    ax.set_ylim(-1, height + 1)
    ax.invert_yaxis()
    ax.grid(True, linestyle=':', alpha=0.3)
    ax.set_axis_on()
    ax.tick_params(left=True, bottom=True, labelleft=False, labelbottom=False)
    ax.set_title(title)
    return fig
def get_preset_dims(preset_mode, topology):
    """Return Gradio updates for the (width, height) inputs of a size preset.

    "Custom" unlocks both inputs and leaves their values alone.  Any other
    preset writes fixed dimensions and locks the inputs.  "linear" uses
    elongated grids, and unknown presets fall back to the "Large" size.
    """
    if preset_mode == "Custom":
        return gr.update(interactive=True), gr.update(interactive=True)

    if topology == "linear":
        sizes, large = {"Small": (4, 4), "Medium": (6, 11)}, (10, 26)
    else:
        sizes, large = {"Small": (4, 4), "Medium": (8, 8)}, (16, 16)
    w, h = sizes.get(preset_mode, large)
    return gr.update(value=w, interactive=False), gr.update(value=h, interactive=False)
def update_ui_for_variant(variant, width, height, topology, void_frac):
    """Return updates for the (void fraction, target nodes, target edges) inputs.

    "Custom" unlocks all three inputs and pre-fills the targets from
    NetworkGenerator.calculate_defaults().  Any other variant locks them,
    zeroes the targets and sets the void fraction to 0.60 when
    width * height <= 20, or 0.35 otherwise.
    """
    if variant != "Custom":
        locked_void = 0.60 if width * height <= 20 else 0.35
        return (
            gr.update(value=locked_void, interactive=False),
            gr.update(value=0, interactive=False),
            gr.update(value=0, interactive=False),
        )

    estimator = NetworkGenerator(width, height, "F", topology, void_frac)
    est_nodes, est_edges = estimator.calculate_defaults()
    return (
        gr.update(interactive=True),
        gr.update(value=est_nodes, interactive=True),
        gr.update(value=est_edges, interactive=True),
    )
def save_single_visual_action(state_data):
    """Render the stored graph to a timestamped PNG in the working directory.

    Args:
        state_data: dict with "graph", "width" and "height" (as produced by
            generate_and_store), or None.

    Returns:
        The PNG file name, or None if no graph has been generated yet.
    """
    if not state_data or "graph" not in state_data:
        return None
    fig = plot_graph(state_data["graph"], state_data["width"], state_data["height"], "Network Visual")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"network_visual_{timestamp}.png"
    try:
        fig.savefig(fname)
    finally:
        # Release the figure even if saving fails, so figures do not pile up.
        plt.close(fig)
    return fname
def generate_and_store(topology, width, height, variant, void_frac, t_nodes, t_edges):
    """Generate a network for the UI and return the five Gradio outputs.

    Returns:
        ``(figure, metrics_markdown, state_dict, save_button_update,
        file_update)``.  On any error it returns ``(None, "Error: ...", None)``
        and disables both the save button and the file component.
    """
    try:
        if variant == "Fixed":
            # The fixed variant ignores custom node/edge targets.
            var_code, t_nodes, t_edges = "F", 0, 0
        else:
            var_code = "R"
        graph = NetworkGenerator(width, height, var_code, topology, void_frac, t_nodes, t_edges).generate()
        n_nodes = len(graph.nodes())
        n_edges = len(graph.edges())
        fig = plot_graph(graph, width, height, f"{topology} ({n_nodes}N, {n_edges}E)")
        metrics_md = f"**Nodes:** {n_nodes} | **Edges:** {n_edges} | **Density:** {nx.density(graph):.2f}"
        stored = {"graph": graph, "width": width, "height": height, "topology": topology}
    except Exception as e:
        return None, f"Error: {e}", None, gr.update(interactive=False), gr.update(interactive=False)
    return fig, metrics_md, stored, gr.update(interactive=True), gr.update(interactive=True)
def manual_edit_action(action, x, y, node_id, state_data):
    """Apply a manual add/delete edit to the stored graph.

    Args:
        action: "Add Node" adds a node at (x, y); any other value deletes the
            node whose 1-based ID is ``node_id``.
        x, y: coordinates for "Add Node" (may arrive as floats or None from
            the UI).
        node_id: ID of the node to delete.
        state_data: dict from generate_and_store, or None.

    Returns:
        ``(plot, metrics_markdown, state_data)``.  On failure the plot is left
        unchanged (``gr.update()``) and the message starts with "Error:".
    """
    if not state_data or "graph" not in state_data:
        return None, "No graph.", state_data
    gen = NetworkGenerator(state_data["width"], state_data["height"])
    gen.graph = state_data["graph"]

    # Node to draw in red after a successful add.
    highlight = None
    if action == "Add Node":
        try:
            x, y = int(x), int(y)
        except (TypeError, ValueError, OverflowError):
            # A cleared gr.Number yields None; report instead of crashing.
            return gr.update(), "Error: Invalid coordinates.", state_data
        success, msg = gen.manual_add_node(x, y)
        if success:
            highlight = (x, y)
    else:
        success, msg = gen.manual_delete_node_by_id(node_id)

    if not success:
        return gr.update(), f"Error: {msg}", state_data
    fig = plot_graph(gen.graph, state_data["width"], state_data["height"], "Edited", highlight_node=highlight)
    metrics = f"**Nodes:** {len(gen.graph.nodes())} | **Edges:** {len(gen.graph.edges())} | {msg}"
    state_data["graph"] = gen.graph
    return fig, metrics, state_data
def run_batch_generation(count, topology, width, height, variant, void_frac, t_nodes, t_edges):
    """Generate ``count`` instances as JSON files and bundle them into a ZIP.

    The files are written to a scratch ``batch_<timestamp>`` directory in the
    working directory and archived to ``batch_<timestamp>.zip``.  The scratch
    directory is then removed, whether or not generation succeeded.

    Returns:
        Path of the ZIP archive, or None if generation failed (the error is
        logged).
    """
    import logging

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dir_name = f"batch_{timestamp}"
    os.makedirs(dir_name, exist_ok=True)
    var_code = "F" if variant == "Fixed" else "R"
    if variant == "Fixed":
        t_nodes, t_edges = 0, 0
    try:
        for i in range(int(count)):
            gen = NetworkGenerator(width, height, var_code, topology, void_frac, t_nodes, t_edges)
            json_content = generate_full_json_dict(gen.generate(), loop=i + 1)
            with open(os.path.join(dir_name, f"inst_{i + 1}.json"), "w", encoding="utf-8") as f:
                json.dump(json_content, f, indent=4)
        # The archive is created before the finally clause deletes its source.
        return shutil.make_archive(dir_name, "zip", dir_name)
    except Exception:
        logging.getLogger(__name__).exception("Batch generation failed")
        return None
    finally:
        shutil.rmtree(dir_name, ignore_errors=True)
# ==========================================
# GRADIO UI
# ==========================================
# Gradio UI: controls on the left, metrics and plot on the right.
with gr.Blocks(title="Graph Generator Pro") as demo:
    # Per-session store: the dict returned by generate_and_store
    # ({"graph", "width", "height", "topology"}), or None.
    state = gr.State()
    gr.Markdown("# Spatial Network Generator Pro")
    with gr.Row():
        # Left column: configuration, manual editor and batch export.
        with gr.Column(scale=1):
            with gr.Tab("Config"):
                topology = gr.Dropdown(["highly_connected", "bottlenecks", "linear"], value="highly_connected", label="Topology")
                preset = gr.Radio(["Small", "Medium", "Large", "Custom"], value="Medium", label="Preset")
                with gr.Row():
                    # Locked unless the "Custom" preset is chosen (see get_preset_dims).
                    width = gr.Number(8, label="Width", interactive=False)
                    height = gr.Number(8, label="Height", interactive=False)
                variant = gr.Dropdown(["Fixed", "Custom"], value="Fixed", label="Variant")
                # Void fraction and node/edge targets are editable only for the
                # "Custom" variant (see update_ui_for_variant).
                void_frac = gr.Slider(0.0, 0.9, 0.35, label="Void Fraction", interactive=False)
                gr.Markdown("### Custom Overrides")
                with gr.Row():
                    t_nodes = gr.Number(0, label="Nodes", interactive=False)
                    t_edges = gr.Number(0, label="Edges", interactive=False)
                gen_btn = gr.Button("Generate", variant="primary")
                # Enabled by generate_and_store after a successful generation.
                save_viz_btn = gr.Button("Download Visual", interactive=False)
                viz_file = gr.File(label="Saved Visual", interactive=False, visible=False)
            with gr.Tab("Editor"):
                with gr.Tab("Add"):
                    with gr.Row():
                        ed_x = gr.Number(0, label="X", precision=0)
                        ed_y = gr.Number(0, label="Y", precision=0)
                    btn_add = gr.Button("Add Node at (X,Y)")
                with gr.Tab("Delete"):
                    ed_id = gr.Number(1, label="Node Number (ID)", precision=0)
                    btn_del = gr.Button("Delete Node ID")
            with gr.Tab("Batch"):
                batch_count = gr.Slider(1, 50, 5, step=1, label="Count")
                batch_btn = gr.Button("Generate Batch ZIP")
                file_out = gr.File(label="Download ZIP")
        # Right column: metrics line and the network plot.
        with gr.Column(scale=2):
            metrics = gr.Markdown("Ready.")
            plot = gr.Plot()

    # Preset or topology changes refresh (and lock/unlock) the grid size.
    inputs_dims = [preset, topology]
    preset.change(get_preset_dims, inputs_dims, [width, height])
    topology.change(get_preset_dims, inputs_dims, [width, height])

    # Variant, size or topology changes refresh the void fraction and targets.
    inputs_var = [variant, width, height, topology, void_frac]
    variant.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
    width.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
    height.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])
    topology.change(update_ui_for_variant, inputs_var, [void_frac, t_nodes, t_edges])

    gen_args = [topology, width, height, variant, void_frac, t_nodes, t_edges]
    gen_btn.click(generate_and_store, gen_args, [plot, metrics, state, save_viz_btn, viz_file])
    # Save the PNG first, then reveal the file component that holds it.
    save_viz_btn.click(save_single_visual_action, [state], [viz_file]).then(
        lambda: gr.update(visible=True), None, [viz_file]
    )
    # Constant gr.State inputs select the action and fill the unused arguments
    # of manual_edit_action.
    btn_add.click(manual_edit_action, [gr.State("Add Node"), ed_x, ed_y, gr.State(0), state], [plot, metrics, state])
    btn_del.click(manual_edit_action, [gr.State("Del Node"), gr.State(0), gr.State(0), ed_id, state], [plot, metrics, state])
    batch_args = [batch_count, topology, width, height, variant, void_frac, t_nodes, t_edges]
    batch_btn.click(run_batch_generation, batch_args, [file_out])

if __name__ == "__main__":
    demo.launch()