Spaces:
Sleeping
Sleeping
| """ | |
| Visualization Module - Generate concept knowledge graphs | |
| """ | |
| import matplotlib.pyplot as plt | |
| import networkx as nx | |
| import matplotlib | |
| import io | |
| import base64 | |
| import os | |
| from typing import Dict, Any, List | |
# Use the non-interactive Agg backend so figures can be rendered without a GUI.
matplotlib.use('Agg')

# Set up Chinese font support: try each candidate family in order and use the
# first one that is actually installed.
font_found = False
chinese_fonts = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'AR PL UMing CN', 'STSong', 'NSimSun', 'FangSong', 'KaiTi']
for font in chinese_fonts:
    try:
        # With the default fallback_to_default=True, findfont() never raises
        # for a missing family -- it returns the default font instead -- so
        # the first candidate would always look "found". Disabling the
        # fallback makes a missing family raise ValueError.
        matplotlib.font_manager.findfont(font, fallback_to_default=False)
    except ValueError:
        # This family is not installed; try the next candidate.
        continue
    matplotlib.rcParams['font.sans-serif'] = [font, 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
    print(f"Using Chinese font: {font}")
    font_found = True
    break

if not font_found:
    print("Warning: No suitable Chinese font found, using default font")
    matplotlib.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial Unicode MS', 'sans-serif']

# Render the minus sign as ASCII '-' instead of the Unicode minus glyph.
matplotlib.rcParams['axes.unicode_minus'] = False
matplotlib.rcParams['font.size'] = 10
def create_network_graph(concepts_data: Dict[str, Any]) -> str:
    """
    Create an enhanced network visualization of concept relationships

    Only sub-concepts become nodes; the main concept appears in the title.
    Relationships are drawn only when both endpoints are known sub-concept
    ids, and only the 'prerequisite' (solid red) and 'related' (dashed blue)
    relationship types are rendered.

    Args:
        concepts_data: Dictionary containing concept hierarchy and
            relationships. Keys read here:
            - "sub_concepts": iterable of dicts with "id", "name" and an
              optional "difficulty" ('basic', 'intermediate' or 'advanced';
              anything else is colored like 'basic'). Entries missing an id
              or a name are skipped.
            - "relationships": iterable of dicts with "source", "target"
              and "type".
            - "main_concept": text used in the figure title.
            A missing or None "sub_concepts"/"relationships" value is
            treated as empty.

    Returns:
        Base64 encoded PNG image as data URL
    """
    difficulty_colors = {
        'basic': '#90CAF9',         # Light blue
        'intermediate': '#FFB74D',  # Orange
        'advanced': '#EF5350'       # Red
    }
    default_color = difficulty_colors['basic']

    G = nx.DiGraph()

    # Only add subconcepts (skip main concept)
    for concept in concepts_data.get("sub_concepts") or []:
        concept_id = concept.get("id")
        concept_name = concept.get("name")
        difficulty = concept.get("difficulty", "basic")
        if concept_id and concept_name:
            G.add_node(
                concept_id,
                name=concept_name,
                type="sub",
                difficulty=difficulty,
                color=difficulty_colors.get(difficulty, default_color)
            )

    # Add relationships between existing subconcepts only; anything that
    # references the main concept or an unknown id is dropped.
    for relation in concepts_data.get("relationships") or []:
        source = relation.get("source")
        target = relation.get("target")
        if source and target and source in G.nodes and target in G.nodes:
            G.add_edge(source, target, type=relation.get("type"))

    # Draw on an explicit figure/axes instead of pyplot's implicit "current"
    # figure, so figures owned by other code are left alone and this one is
    # always released, even if drawing or saving fails.
    fig = plt.figure(figsize=(14, 10), dpi=150, facecolor='white')
    try:
        ax = fig.add_subplot(111)

        pos = nx.spring_layout(
            G,
            k=2.0,           # Increase node spacing
            iterations=100,  # More iterations for a better layout
            seed=42          # Fixed random seed for consistent layout
        )

        # All nodes share one size since there is no main-concept node.
        node_colors = [G.nodes[node].get('color', default_color) for node in G.nodes()]
        node_sizes = [1500 for _ in G.nodes()]
        nx.draw_networkx_nodes(
            G, pos,
            node_color=node_colors,
            node_size=node_sizes,
            alpha=0.8,
            ax=ax
        )

        # Split edges by relationship type; other types are not drawn.
        edges_prerequisite = [(u, v) for (u, v, d) in G.edges(data=True) if d.get('type') == 'prerequisite']
        edges_related = [(u, v) for (u, v, d) in G.edges(data=True) if d.get('type') == 'related']

        # Opposite curvature for the two styles keeps a prerequisite edge
        # and a related edge between the same pair from overlapping.
        nx.draw_networkx_edges(
            G, pos,
            edgelist=edges_prerequisite,
            edge_color='red',
            width=2,
            connectionstyle="arc3,rad=0.2",
            arrowsize=20,
            arrowstyle='->',
            min_source_margin=30,
            min_target_margin=30,
            ax=ax
        )
        nx.draw_networkx_edges(
            G, pos,
            edgelist=edges_related,
            edge_color='blue',
            style='dashed',
            width=1.5,
            connectionstyle="arc3,rad=-0.2",
            arrowsize=15,
            arrowstyle='->',
            min_source_margin=25,
            min_target_margin=25,
            ax=ax
        )

        # Label each node with its display name, offset upward from the
        # node position.
        labels = {node: G.nodes[node].get('name', node) for node in G.nodes()}
        label_pos = {node: (coord[0], coord[1] + 0.08) for node, coord in pos.items()}
        nx.draw_networkx_labels(
            G,
            label_pos,
            labels=labels,
            font_size=12,
            font_weight='bold',
            bbox={  # Text background
                'facecolor': 'white',
                'edgecolor': '#E0E0E0',
                'alpha': 0.9,
                'pad': 6,
                'boxstyle': 'round,pad=0.5'
            },
            ax=ax
        )

        legend_elements = [
            plt.Line2D([0], [0], color='red', lw=2, label='Prerequisite'),
            plt.Line2D([0], [0], color='blue', linestyle='--', lw=2, label='Related'),
            plt.Line2D([0], [0], marker='o', color='w', label='Basic', markerfacecolor='#90CAF9', markersize=12),
            plt.Line2D([0], [0], marker='o', color='w', label='Intermediate', markerfacecolor='#FFB74D', markersize=12),
            plt.Line2D([0], [0], marker='o', color='w', label='Advanced', markerfacecolor='#EF5350', markersize=12)
        ]
        ax.legend(
            handles=legend_elements,
            loc='upper right',
            bbox_to_anchor=(1.2, 1),
            fontsize=10,
            frameon=True,
            facecolor='white',
            edgecolor='none',
            shadow=True
        )

        # Show the main concept in the title without creating a node for it
        main_concept = concepts_data.get("main_concept", "Concept Map")
        ax.set_title(f"Concept Map: {main_concept}", pad=20, fontsize=14, fontweight='bold')

        # Increase graph margins
        ax.margins(x=0.2, y=0.2)
        ax.axis('off')
        fig.tight_layout()

        # Add padding when saving the image
        buf = io.BytesIO()
        fig.savefig(
            buf,
            format='png',
            bbox_inches='tight',
            dpi=150,
            pad_inches=0.5
        )
    finally:
        plt.close(fig)

    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')