Spaces:
Sleeping
Sleeping
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| from datetime import datetime | |
| import os | |
| import shutil | |
# Pixel dimensions of every rendered plot image.
IMG_WIDTH_PX, IMG_HEIGHT_PX = 800, 800
# Default output folder for plot images; emptied every time this module loads.
TEMP_DIR = "temp_visuals"


def _reset_temp_dir():
    """Remove TEMP_DIR with all of its contents, then create it again empty."""
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    os.makedirs(TEMP_DIR, exist_ok=True)


# Runs once at import time so each start of the app begins with a clean folder.
_reset_temp_dir()
def get_sorted_nodes(G):
    """Return the nodes of *G* ordered by x coordinate, then by y coordinate.

    Nodes are expected to be indexable ``(x, y)`` pairs; only the first two
    components are used as the sort key. The deterministic order lets callers
    derive consistent 1-based IDs for the same set of nodes. Nodes that tie on
    both components keep the order in which ``G.nodes()`` yields them, because
    ``sorted`` is stable.

    Args:
        G: graph-like object exposing ``nodes()``.

    Returns:
        A new list of the nodes in sorted order.
    """
    # sorted() accepts any iterable, so no intermediate list copy is needed.
    return sorted(G.nodes(), key=lambda node: (node[0], node[1]))
def _marker_sizes(max_dim):
    """Return ``(node_size, font_size, highlight_size)`` for a grid whose larger side is *max_dim*.

    Larger grids get smaller markers and labels so neighbouring nodes do not overlap.
    """
    if max_dim <= 6:
        return 900, 12, 1100
    if max_dim <= 10:
        return 500, 9, 650
    if max_dim <= 16:
        return 200, 7, 280
    if max_dim <= 24:
        return 100, 5, 140
    return 50, 4, 80


def plot_graph_to_image(graph, width, height, title="Network", highlight_node=None, save_dir=TEMP_DIR):
    """Render *graph* on a grid and save it as a PNG file.

    Each node is drawn at ``(node[0], node[1])``, so nodes are expected to be
    indexable coordinate pairs. The y axis is reversed, so y = 0 is the top
    row. Nodes are labelled ``1..N`` in the order given by ``get_sorted_nodes``
    (by x, then y).

    Args:
        graph: networkx graph whose nodes are coordinate pairs.
        width: grid width; sets the x-axis limits and, with *height*, the
            marker/label size tier.
        height: grid height; sets the y-axis limits.
        title: currently not drawn on the image; kept for API compatibility.
        highlight_node: node drawn larger and in red. Ignored when ``None`` or
            not present in *graph*.
        save_dir: directory to write into; created if missing. Files written
            to ``TEMP_DIR`` get the ``temp_plot`` prefix, all others ``saved_plot``.

    Returns:
        Path of the written PNG. The figure is sized
        IMG_WIDTH_PX x IMG_HEIGHT_PX pixels at 100 dpi.
    """
    dpi = 100
    fig = plt.figure(figsize=(IMG_WIDTH_PX / dpi, IMG_HEIGHT_PX / dpi), dpi=dpi)
    try:
        # Axes fill the whole canvas: no margins around the grid.
        ax = fig.add_axes([0, 0, 1, 1])
        pos = {node: (node[0], node[1]) for node in graph.nodes()}

        # Dynamic sizing to prevent jamming on larger grids.
        n_sz, f_sz, h_sz = _marker_sizes(max(width, height))

        nx.draw_networkx_edges(graph, pos, ax=ax, width=2, alpha=0.6, edge_color="#333")

        normal_nodes = [n for n in graph.nodes() if n != highlight_node]
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, nodelist=normal_nodes, node_size=n_sz,
            node_color="#4F46E5", edgecolors="white", linewidths=1.5,
        )
        # `is not None` so a falsy-but-valid node value can still be highlighted.
        if highlight_node is not None and graph.has_node(highlight_node):
            nx.draw_networkx_nodes(
                graph, pos, ax=ax, nodelist=[highlight_node], node_size=h_sz,
                node_color="#EF4444", edgecolors="white", linewidths=2.0,
            )

        labels = {node: str(i) for i, node in enumerate(get_sorted_nodes(graph), start=1)}
        nx.draw_networkx_labels(
            graph, pos, labels, ax=ax, font_size=f_sz,
            font_color="white", font_weight="bold",
        )

        # Half-cell padding on every side; y limits reversed so row 0 is at the top.
        ax.set_xlim(-0.5, width + 0.5)
        ax.set_ylim(height + 0.5, -0.5)
        ax.grid(True, linestyle=':', alpha=0.3)
        ax.set_axis_on()

        # Without this, savefig raises FileNotFoundError for a save_dir that does
        # not exist yet.
        os.makedirs(save_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        prefix = "temp_plot" if save_dir == TEMP_DIR else "saved_plot"
        fname = os.path.join(save_dir, f"{prefix}_{timestamp}.png")
        # Save this specific figure instead of pyplot's process-global "current"
        # figure, which other code creating figures could change. The explicit
        # dpi keeps the pixel size fixed regardless of rcParams["savefig.dpi"].
        fig.savefig(fname, dpi=dpi)
    finally:
        # Always release the figure, even if drawing or saving fails, so pyplot
        # does not accumulate open figures.
        plt.close(fig)
    return fname