Spaces:
Sleeping
Sleeping
| """ | |
| Visual Component Library - Beginner-Friendly ML Visualizations | |
| Multiple synchronized views with educational annotations | |
| """ | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
# Shared colour palette referenced by the chart builders in this module.
COLORS = dict(
    primary='#667eea',
    secondary='#764ba2',
    accent='#f093fb',
    success='#43e97b',
    warning='#fa709a',
    info='#4facfe',
    # Ordered cycle; charts index it modulo its length to colour series.
    gradient=['#667eea', '#764ba2', '#f093fb', '#4facfe', '#43e97b', '#fa709a'],
    # Presumably a Plotly colourscale name; the visible renderers hard-code
    # 'Viridis' themselves rather than reading this entry.
    heatmap='Viridis',
)
# Layout keyword arguments splatted into every figure's update_layout() call
# so all charts share the same dark theme, font, margins and hover label.
LAYOUT_DEFAULTS = {
    'template': 'plotly_dark',
    'paper_bgcolor': 'rgba(26,26,46,0.9)',
    'plot_bgcolor': 'rgba(26,26,46,0.9)',
    'font': {'family': 'Inter, sans-serif', 'color': '#e0e0e0', 'size': 12},
    'margin': {'l': 60, 'r': 60, 't': 80, 'b': 60},
    'hoverlabel': {'bgcolor': '#1a1a2e', 'font_size': 14},
}
def create_placeholder_figure(title: str, message: str) -> go.Figure:
    """
    Build an axis-less figure that shows an error notice instead of a chart.

    Renderers return this when their input is missing or invalid, so the user
    sees an explicit message rather than a chart built from made-up defaults.

    Args:
        title: Short heading; shown bold inside the notice box and again,
            prefixed with a warning icon, as the figure title.
        message: Explanatory body text placed under the heading. It is
            interpolated into the annotation text unchanged.

    Returns:
        A figure containing one centred, bordered annotation and hidden axes.
    """
    alert_color = '#ff6b6b'

    notice = dict(
        x=0.5, y=0.5,
        xref='paper', yref='paper',  # centre of the whole figure, not data coords
        text=f"<b>{title}</b><br><br>{message}",
        showarrow=False,
        font=dict(size=16, color=alert_color),
        align='center',
        bgcolor='rgba(255,107,107,0.1)',
        bordercolor=alert_color,
        borderwidth=2,
        borderpad=20,
    )
    heading = dict(text=f"<b>⚠️ {title}</b>", font=dict(size=16, color=alert_color))

    placeholder = go.Figure()
    placeholder.add_annotation(**notice)
    placeholder.update_layout(
        title=heading,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        **LAYOUT_DEFAULTS,
    )
    return placeholder
def scatter_cluster(config: dict) -> go.Figure:
    """
    K-Means style cluster scatter plot with educational annotations.

    Synthetic data: cluster centres are spaced evenly on a circle of radius 4
    and Gaussian points (std 1.2) are drawn around each centre. The "centroids"
    drawn are these generating centres, not centres re-estimated from points.

    Config keys (all optional):
        title: Chart title.
        n_clusters / k / num_clusters: Number of clusters (default 3, min 1).
        n_points / data_points: Total number of points (default 150).
        show_centroids: Draw centre markers and the explanatory arrow
            (default True).
        seed: Seed passed to ``np.random.seed`` (default 42). Note this
            reseeds NumPy's global generator.

    Returns:
        A figure with one scatter trace per cluster, plus a centroid trace
        when ``show_centroids`` is truthy.
    """
    title = config.get('title', 'K-Means Clustering')
    # Clamp so a negative count cannot produce an empty centre array below.
    n_clusters = max(1, int(config.get('n_clusters') or config.get('k') or config.get('num_clusters') or 3))
    n_points = max(0, int(config.get('n_points') or config.get('data_points') or 150))
    show_centroids = config.get('show_centroids', True)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Generate clear, separated clusters for beginners
    radius = 4
    angle_step = 2 * np.pi / n_clusters
    centers = np.array([
        [radius * np.cos(i * angle_step), radius * np.sin(i * angle_step)]
        for i in range(n_clusters)
    ])

    # Spread the remainder over the first clusters so the number of points
    # drawn equals n_points, which is the number the info box reports.
    base_size, remainder = divmod(n_points, n_clusters)
    cluster_sizes = [base_size + (1 if i < remainder else 0) for i in range(n_clusters)]

    # One (size, 2) array per cluster; stays 2-D even when a cluster is empty.
    clusters = [
        center + np.random.randn(size, 2) * 1.2
        for center, size in zip(centers, cluster_sizes)
    ]

    fig = go.Figure()
    palette = COLORS['gradient']

    # Plot clusters
    for i, cluster_points in enumerate(clusters):
        fig.add_trace(go.Scatter(
            x=cluster_points[:, 0], y=cluster_points[:, 1],
            mode='markers',
            marker=dict(size=10, color=palette[i % len(palette)], opacity=0.7,
                        line=dict(color='white', width=1)),
            name=f'Cluster {i+1}',
            hovertemplate=f'<b>Cluster {i+1}</b><br>Position: (%{{x:.1f}}, %{{y:.1f}})<extra></extra>',
        ))

    # Centroids with educational annotation
    if show_centroids:
        fig.add_trace(go.Scatter(
            x=centers[:, 0], y=centers[:, 1],
            mode='markers+text',
            marker=dict(size=25, color='white', symbol='x', line=dict(color='#333', width=3)),
            text=[f'C{i+1}' for i in range(n_clusters)],
            textposition='top center',
            textfont=dict(size=14, color='white'),
            name='Centroids (Cluster Centers)',
            hovertemplate='<b>Centroid %{text}</b><br>This is the "center" of the cluster<br>Position: (%{x:.1f}, %{y:.1f})<extra></extra>',
        ))
        # Educational annotations
        fig.add_annotation(
            x=centers[0, 0], y=centers[0, 1] + 2,
            text="← Centroid: The algorithm tries to<br>minimize distance from points to here",
            showarrow=True, arrowhead=2, arrowcolor=COLORS['info'],
            font=dict(size=11, color=COLORS['info']),
            ax=80, ay=-30
        )

    # Add insight annotation
    fig.add_annotation(
        x=0.02, y=0.98, xref='paper', yref='paper',
        text=f"<b>K = {n_clusters} clusters</b><br>{n_points} data points total",
        showarrow=False,
        font=dict(size=13, color='white'),
        align='left',
        bgcolor='rgba(102, 126, 234, 0.8)',
        borderpad=8
    )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Each color = one cluster, X = centroid</sub>", font=dict(size=18)),
        xaxis=dict(title='Feature 1', zeroline=False),
        yaxis=dict(title='Feature 2', zeroline=False),
        legend=dict(orientation='h', y=1.12),
        **LAYOUT_DEFAULTS
    )
    return fig
def cluster_distribution(config: dict) -> go.Figure:
    """
    Bar chart of cluster sizes, meant as a companion to ``scatter_cluster``.

    Reads the same config keys as ``scatter_cluster`` (n_clusters, n_points,
    seed). Sizes start from an even split of ``n_points`` and then receive a
    small random jitter (-2..+2 points, only for clusters larger than 5), so
    the counts shown are illustrative and need not sum to ``n_points``.

    Each cluster keeps the palette colour of its index, so "Cluster N" has the
    same colour here as in ``scatter_cluster`` even after the bars are sorted
    by size.

    Config keys (all optional):
        title: Chart title.
        n_clusters / k / num_clusters: Number of clusters (default 3, min 1).
        n_points / data_points: Total number of points (default 150).
        seed: Seed passed to ``np.random.seed`` (default 42); reseeds NumPy's
            global generator.

    Returns:
        A figure with a single bar trace, largest cluster first.
    """
    title = config.get('title', 'Cluster Size Distribution')
    # Clamp so a negative count cannot leave the bar lists empty.
    n_clusters = max(1, int(config.get('n_clusters') or config.get('k') or config.get('num_clusters') or 3))
    n_points = int(config.get('n_points') or config.get('data_points') or 150)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    points_per_cluster, remainder = divmod(n_points, n_clusters)
    palette = COLORS['gradient']

    # Build one record per cluster so name, size and colour stay together
    # through the sort below.
    clusters = []
    for i in range(n_clusters):
        base_size = points_per_cluster + (1 if i < remainder else 0)
        # Add slight variation to make it realistic (not perfectly equal)
        variation = np.random.randint(-2, 3) if base_size > 5 else 0
        clusters.append({
            'name': f'Cluster {i+1}',
            'size': max(1, base_size + variation),
            'color': palette[i % len(palette)],
        })

    # Normalize to get proportions
    total = sum(c['size'] for c in clusters)

    # Sort by size for better visualization (stable: ties keep index order)
    ranked = sorted(clusters, key=lambda c: c['size'], reverse=True)
    cluster_names = [c['name'] for c in ranked]
    cluster_sizes = [c['size'] for c in ranked]
    proportions = [c['size'] / total for c in ranked]
    colors = [c['color'] for c in ranked]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=cluster_names, y=proportions,
        marker=dict(color=colors, line=dict(color='white', width=1)),
        text=[f'{p:.1%}<br>({s} pts)' for p, s in zip(proportions, cluster_sizes)],
        textposition='outside',
        textfont=dict(size=12, color='white'),
        hovertemplate='<b>%{x}</b><br>Size: %{text}<extra></extra>',
    ))

    # Highlight largest cluster
    fig.add_annotation(
        x=cluster_names[0], y=proportions[0] + 0.05,
        text="👆 Largest",
        showarrow=False, font=dict(size=12, color='lime'),
    )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Points per cluster (K={n_clusters})</sub>"),
        xaxis=dict(title='Cluster'),
        yaxis=dict(title='Proportion of Points', tickformat='.0%', range=[0, max(proportions) * 1.3]),
        **LAYOUT_DEFAULTS,
    )
    return fig
def gradient_descent_3d(config: dict) -> go.Figure:
    """
    3D view of gradient descent on the bowl-shaped loss J = x² + y².

    Config keys (all optional):
        title: Chart title.
        learning_rate / lr / alpha: Step size (default 0.1).
        start_x / start_position: Starting x (default 2.0; 0 is honoured).
        start_y: Starting y (default 2.0; 0 is honoured).
        n_steps / steps / iterations: Maximum number of steps (default 30).
        seed: Passed to ``np.random.seed`` (default 42). Nothing random is
            drawn here; the call only resets NumPy's global generator.

    The path stops early once the loss drops below 0.001, or when the loss is
    no longer finite (a diverging run), so the last plotted point is always a
    real number.

    Returns:
        A figure with the loss surface, the descent path, and start/end markers.
    """
    def _start(*keys: str) -> float:
        # A plain `or` chain would throw away a legitimate 0.0 start, so only
        # missing/empty entries fall through to the default.
        for key in keys:
            value = config.get(key)
            if value is not None and value != '':
                return float(value)
        return 2.0

    title = config.get('title', 'Gradient Descent Optimization')
    lr = float(config.get('learning_rate') or config.get('lr') or config.get('alpha') or 0.1)
    start_x = _start('start_x', 'start_position')
    start_y = _start('start_y')
    n_steps = int(config.get('n_steps') or config.get('steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Create loss surface
    x = np.linspace(-3, 3, 60)
    y = np.linspace(-3, 3, 60)
    X, Y = np.meshgrid(x, y)
    Z = X**2 + Y**2  # Simple quadratic - easy to understand

    # Run gradient descent. Squares are written as products because float
    # `**` raises OverflowError on huge values, whereas `*` yields inf that
    # the finiteness check below can catch.
    px, py = start_x, start_y
    path_x, path_y, path_z = [px], [py], [px * px + py * py]
    for _ in range(n_steps):
        # Gradient of x^2 + y^2 is (2x, 2y)
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        pz = px * px + py * py
        if not np.isfinite(pz):
            break  # diverged past float range; keep the last finite point
        path_x.append(px)
        path_y.append(py)
        path_z.append(pz)
        # Stop if converged
        if pz < 0.001:
            break

    fig = go.Figure()

    # Loss surface
    fig.add_trace(go.Surface(
        x=X, y=Y, z=Z,
        colorscale='Viridis',
        opacity=0.85,
        showscale=False,
        name='Loss Surface',
        hovertemplate='Loss at (%{x:.1f}, %{y:.1f}): %{z:.2f}<extra></extra>',
    ))

    # Descent path
    fig.add_trace(go.Scatter3d(
        x=path_x, y=path_y, z=path_z,
        mode='lines+markers',
        marker=dict(size=6, color=COLORS['warning'], symbol='circle'),
        line=dict(color=COLORS['warning'], width=5),
        name='Optimization Path',
        hovertemplate='<b>Step %{pointNumber}</b><br>Position: (%{x:.2f}, %{y:.2f})<br>Loss: %{z:.3f}<extra></extra>',
    ))

    # Start marker
    fig.add_trace(go.Scatter3d(
        x=[path_x[0]], y=[path_y[0]], z=[path_z[0]],
        mode='markers+text',
        marker=dict(size=12, color='red', symbol='diamond'),
        text=['START'],
        textposition='top center',
        name='Starting Point',
    ))

    # End marker
    fig.add_trace(go.Scatter3d(
        x=[path_x[-1]], y=[path_y[-1]], z=[path_z[-1]],
        mode='markers+text',
        marker=dict(size=12, color='lime', symbol='diamond'),
        text=['END'],
        textposition='top center',
        name='Final Point',
    ))

    # Determine what happened (learning-rate heuristics take priority over
    # the final loss value)
    final_loss = path_z[-1]
    if lr > 0.9:
        status = "⚠️ Learning rate too HIGH - overshooting!"
    elif lr < 0.05:
        status = "🐌 Learning rate too LOW - very slow progress"
    elif final_loss < 0.1:
        status = "✅ Good convergence!"
    else:
        status = f"📉 Loss: {final_loss:.3f} (keep optimizing)"

    fig.update_layout(
        title=dict(
            text=f"<b>{title}</b><br><sub>Learning Rate (α) = {lr} | Steps = {len(path_x)-1} | {status}</sub>",
            font=dict(size=16)
        ),
        scene=dict(
            xaxis_title='Parameter θ₁',
            yaxis_title='Parameter θ₂',
            zaxis_title='Loss J(θ)',
            camera=dict(eye=dict(x=1.8, y=1.8, z=1.2)),
            annotations=[
                dict(
                    x=0, y=0, z=0,
                    text="← Global Minimum<br>(What we want to find!)",
                    showarrow=True,
                    arrowhead=2,
                    font=dict(size=12, color='lime'),
                )
            ]
        ),
        **LAYOUT_DEFAULTS,
        height=550,
    )
    return fig
def gradient_descent_2d(config: dict) -> go.Figure:
    """
    Top-down contour view of gradient descent on the loss J = x² + y².

    Accepts the same config aliases as ``gradient_descent_3d`` so both views
    draw the same run when given one shared config.

    Config keys (all optional):
        title: Chart title.
        learning_rate / lr / alpha: Step size (default 0.1).
        start_x / start_position: Starting x (default 2.0; 0 is honoured).
        start_y: Starting y (default 2.0; 0 is honoured).
        n_steps / steps / iterations: Maximum number of steps (default 30).
        seed: Passed to ``np.random.seed`` (default 42). Nothing random is
            drawn here; the call only resets NumPy's global generator.

    The path stops early once the loss drops below 0.001, or when the loss is
    no longer finite (a diverging run).

    Returns:
        A figure with loss contours, the descent path with direction arrows,
        and start / end / minimum markers.
    """
    def _start(*keys: str) -> float:
        # A plain `or` chain would throw away a legitimate 0.0 start, so only
        # missing/empty entries fall through to the default.
        for key in keys:
            value = config.get(key)
            if value is not None and value != '':
                return float(value)
        return 2.0

    title = config.get('title', 'Gradient Descent (Top View)')
    lr = float(config.get('learning_rate') or config.get('lr') or config.get('alpha') or 0.1)
    start_x = _start('start_x', 'start_position')
    start_y = _start('start_y')
    n_steps = int(config.get('n_steps') or config.get('steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Create contour
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)
    Z = X**2 + Y**2

    # Run gradient descent. Squares are written as products because float
    # `**` raises OverflowError on huge values, whereas `*` yields inf that
    # the finiteness check below can catch.
    path_x, path_y = [start_x], [start_y]
    px, py = start_x, start_y
    for _ in range(n_steps):
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        loss = px * px + py * py
        if not np.isfinite(loss):
            break  # diverged past float range; keep the last finite point
        path_x.append(px)
        path_y.append(py)
        if loss < 0.001:
            break

    fig = go.Figure()

    # Contour plot
    fig.add_trace(go.Contour(
        x=x, y=y, z=Z,
        colorscale='Viridis',
        contours=dict(showlabels=True, labelfont=dict(size=10, color='white')),
        name='Loss Contours',
        hovertemplate='Loss: %{z:.2f}<extra></extra>',
    ))

    # Path
    fig.add_trace(go.Scatter(
        x=path_x, y=path_y,
        mode='lines+markers',
        marker=dict(size=8, color=COLORS['warning']),
        line=dict(color=COLORS['warning'], width=3, dash='solid'),
        name='Optimization Path',
        hovertemplate='Step %{pointNumber}<br>(%{x:.2f}, %{y:.2f})<extra></extra>',
    ))

    # Arrows showing direction: roughly five, evenly spaced along the path
    arrow_stride = max(1, len(path_x) // 5)
    for i in range(0, len(path_x) - 1, arrow_stride):
        fig.add_annotation(
            x=path_x[i+1], y=path_y[i+1],
            ax=path_x[i], ay=path_y[i],
            xref='x', yref='y', axref='x', ayref='y',
            showarrow=True, arrowhead=2, arrowsize=1.5,
            arrowcolor=COLORS['warning']
        )

    # Start and end
    fig.add_trace(go.Scatter(
        x=[path_x[0]], y=[path_y[0]],
        mode='markers+text',
        marker=dict(size=15, color='red', symbol='star'),
        text=['START'], textposition='top right',
        name='Start',
    ))
    fig.add_trace(go.Scatter(
        x=[path_x[-1]], y=[path_y[-1]],
        mode='markers+text',
        marker=dict(size=15, color='lime', symbol='star'),
        text=['END'], textposition='top right',
        name='End',
    ))

    # Minimum point
    fig.add_trace(go.Scatter(
        x=[0], y=[0],
        mode='markers+text',
        marker=dict(size=20, color='white', symbol='x'),
        text=['MINIMUM'], textposition='bottom center',
        name='Global Minimum',
    ))

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Bird's eye view - contour lines show equal loss</sub>"),
        xaxis=dict(title='θ₁', scaleanchor='y'),
        yaxis=dict(title='θ₂'),
        **LAYOUT_DEFAULTS,
        height=500,
    )
    return fig
def loss_curve(config: dict) -> go.Figure:
    """
    Loss-versus-step curve for gradient descent on J = x² + y².

    Accepts the same config aliases as ``gradient_descent_3d`` so the curve
    describes the same run as the surface/contour views. Unlike those views,
    this one does not stop at convergence: it records every requested step,
    stopping early only if the loss stops being finite (a diverging run).

    Config keys (all optional):
        title: Chart title.
        learning_rate / lr / alpha: Step size (default 0.1).
        start_x / start_position: Starting x (default 2.0; 0 is honoured).
        start_y: Starting y (default 2.0; 0 is honoured).
        n_steps / steps / iterations: Number of steps (default 30).
        seed: Passed to ``np.random.seed`` (default 42). Nothing random is
            drawn here; the call only resets NumPy's global generator.

    Returns:
        A figure with the loss trace, start/final loss callouts, and either a
        "converged" guide line or a "not converging" warning when applicable.
    """
    def _start(*keys: str) -> float:
        # A plain `or` chain would throw away a legitimate 0.0 start, so only
        # missing/empty entries fall through to the default.
        for key in keys:
            value = config.get(key)
            if value is not None and value != '':
                return float(value)
        return 2.0

    title = config.get('title', 'Loss Over Training Steps')
    lr = float(config.get('learning_rate') or config.get('lr') or config.get('alpha') or 0.1)
    start_x = _start('start_x', 'start_position')
    start_y = _start('start_y')
    n_steps = int(config.get('n_steps') or config.get('steps') or config.get('iterations') or 30)
    seed = config.get('seed', 42)
    np.random.seed(seed)

    # Calculate loss at each step. Squares are written as products because
    # float `**` raises OverflowError on huge values, whereas `*` yields inf
    # that the finiteness check below can catch.
    px, py = start_x, start_y
    losses = [px * px + py * py]
    for _ in range(n_steps):
        gx, gy = 2 * px, 2 * py
        px = px - lr * gx
        py = py - lr * gy
        loss = px * px + py * py
        if not np.isfinite(loss):
            break  # diverged past float range; keep the last finite value
        losses.append(loss)
    steps = list(range(len(losses)))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps, y=losses,
        mode='lines+markers',
        marker=dict(size=8, color=COLORS['primary']),
        line=dict(color=COLORS['primary'], width=3),
        fill='tozeroy',
        fillcolor='rgba(102, 126, 234, 0.2)',
        name='Training Loss',
        hovertemplate='Step %{x}<br>Loss: %{y:.4f}<extra></extra>',
    ))

    # Add annotations
    fig.add_annotation(
        x=0, y=losses[0],
        text=f"Starting Loss: {losses[0]:.2f}",
        showarrow=True, arrowhead=2,
        font=dict(size=11), ax=50, ay=-30
    )
    fig.add_annotation(
        x=len(losses)-1, y=losses[-1],
        text=f"Final Loss: {losses[-1]:.4f}",
        showarrow=True, arrowhead=2,
        font=dict(size=11, color='lime'), ax=-50, ay=-30
    )

    # Good/bad indicator
    if losses[-1] < 0.01:
        fig.add_hline(y=0.01, line_dash="dash", line_color="lime",
                      annotation_text="✅ Converged!", annotation_position="right")
    elif losses[-1] > losses[0] * 0.9:
        # Less than a 10% reduction over the whole run counts as "stuck".
        fig.add_annotation(
            x=0.5, y=0.5, xref='paper', yref='paper',
            text="⚠️ Not converging well - try adjusting learning rate",
            font=dict(size=14, color='orange'),
            showarrow=False, bgcolor='rgba(0,0,0,0.7)', borderpad=10
        )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Watch the loss decrease as we optimize</sub>"),
        xaxis=dict(title='Training Step'),
        # Switch to a log axis when the loss blows up so the curve stays readable.
        yaxis=dict(title='Loss', type='log' if max(losses) > 100 else 'linear'),
        **LAYOUT_DEFAULTS,
        height=400,
    )
    return fig
def flow_diagram(config: dict) -> go.Figure:
    """
    Neural network architecture diagram with annotations.

    Config keys (all optional):
        title: Chart title.
        layers: List describing the layers, left to right. Each entry is
            either a neuron count (int, or anything ``int()`` accepts) or a
            dict with optional ``name``, ``nodes`` (default 3) and ``color``.
            Entries whose neuron count is not a positive integer are skipped.
            A missing or empty list falls back to ``[3, 4, 4, 2]``.
        seed: Passed to ``np.random.seed`` (default 42) so the random
            connection weights, and therefore line opacity/width, are the
            same on every render. This reseeds NumPy's global generator, as
            the other renderers in this module do.

    Returns:
        The architecture figure, or a placeholder figure when no usable layer
        remains after validation.
    """
    title = config.get('title', 'Neural Network Architecture')
    raw_layers = config.get('layers') or [3, 4, 4, 2]
    np.random.seed(config.get('seed', 42))

    layer_colors = [COLORS['info'], COLORS['primary'], COLORS['secondary'], COLORS['accent'], COLORS['success']]

    # Normalize layers into {'name', 'nodes', 'color'} records
    layers = []
    last_index = len(raw_layers) - 1
    for i, layer in enumerate(raw_layers):
        default_color = layer_colors[i % len(layer_colors)]
        if isinstance(layer, dict):
            name = layer.get('name', f'Layer {i+1}')
            nodes = layer.get('nodes', 3)
            color = layer.get('color', default_color)
        else:
            name = 'Input' if i == 0 else ('Output' if i == last_index else f'Hidden {i}')
            nodes = layer
            color = default_color
        try:
            nodes = int(nodes)  # JSON configs may deliver 4.0 or "4"
        except (TypeError, ValueError):
            continue
        if nodes < 1:
            continue
        layers.append({'name': name, 'nodes': nodes, 'color': color})

    if not layers:
        return create_placeholder_figure(
            "Invalid Data",
            "'layers' must contain at least one layer with a positive neuron count."
        )

    def _node_y(count: int) -> np.ndarray:
        # linspace with a single sample returns only the lower bound, which
        # would pin a one-neuron layer to the bottom; centre it instead.
        return np.array([0.5]) if count == 1 else np.linspace(0.2, 0.8, count)

    fig = go.Figure()
    layer_x = np.linspace(0.1, 0.9, len(layers))

    # Draw connections: one line trace per neuron pair, because each line
    # carries its own random opacity and width.
    for i in range(len(layers) - 1):
        y1 = _node_y(layers[i]['nodes'])
        y2 = _node_y(layers[i+1]['nodes'])
        for y1_pos in y1:
            for y2_pos in y2:
                weight = np.random.uniform(0.2, 1.0)
                fig.add_trace(go.Scatter(
                    x=[layer_x[i], layer_x[i+1]], y=[y1_pos, y2_pos],
                    mode='lines',
                    line=dict(color=f'rgba(102, 126, 234, {weight * 0.5})', width=weight * 2),
                    hoverinfo='skip', showlegend=False,
                ))

    # Draw nodes
    for layer, x_pos in zip(layers, layer_x):
        n = layer['nodes']
        fig.add_trace(go.Scatter(
            x=[x_pos] * n, y=_node_y(n),
            mode='markers+text',
            marker=dict(size=35, color=layer['color'], line=dict(color='white', width=2)),
            text=[str(j+1) for j in range(n)],
            textposition='middle center',
            textfont=dict(color='white', size=11),
            name=layer['name'],
            hovertemplate=f"<b>{layer['name']}</b><br>Neuron %{{text}}<extra></extra>",
        ))
        # Layer label
        fig.add_annotation(
            x=x_pos, y=-0.05, text=f"<b>{layer['name']}</b><br>({n} neurons)",
            showarrow=False, font=dict(size=12, color=layer['color']),
        )

    # Educational annotations
    fig.add_annotation(
        x=layer_x[0], y=0.95,
        text="📥 Input<br>Your data goes here",
        showarrow=False, font=dict(size=10, color=COLORS['info']),
    )
    fig.add_annotation(
        x=layer_x[-1], y=0.95,
        text="📤 Output<br>Predictions come out here",
        showarrow=False, font=dict(size=10, color=COLORS['success']),
    )
    if len(layers) > 2:
        mid = len(layers) // 2
        fig.add_annotation(
            x=layer_x[mid], y=0.95,
            text="🧠 Hidden Layers<br>Learn patterns",
            showarrow=False, font=dict(size=10, color=COLORS['primary']),
        )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Data flows left → right through connected neurons</sub>"),
        xaxis=dict(visible=False, range=[-0.05, 1.05]),
        yaxis=dict(visible=False, range=[-0.15, 1.05]),
        **LAYOUT_DEFAULTS,
        height=500,
        showlegend=False,
    )
    return fig
def matrix_heatmap(config: dict) -> go.Figure:
    """
    PURE RENDERER - Matrix heatmap for attention weights, feature maps, confusion matrices.

    REQUIRES from LLM:
    - labels: list of row/column labels (e.g., ["Token 1", "Token 2"] or ["Filter A", "Filter B"])
    - values: 2D array of values (e.g., [[0.1, 0.2], [0.3, 0.4]])
    - x_title: title for x-axis (e.g., "Keys" for attention, "Feature" for CNN)
    - y_title: title for y-axis (e.g., "Queries" for attention, "Filter" for CNN)

    OPTIONAL:
    - title, subtitle, colorbar_title, hover_template

    INTERACTIVE PARAMS:
    - focus_row: Index of row to highlight (1-indexed for user-friendliness).
      Other rows are dimmed to 30% of their value (this also changes the
      numbers shown on hover for those rows) and the row gets a border.
      An out-of-range index is ignored.
    - threshold: Only show values at or above this threshold; smaller
      values are drawn as 0. It is applied to the original values, before any
      focus dimming, so dimming never pushes a cell below the threshold.

    Does NOT generate data internally - LLM must provide concept-specific data.
    Missing or malformed input yields a placeholder figure with an error
    message instead of raising.
    """
    title = config.get('title', 'Matrix Visualization')
    subtitle = config.get('subtitle', 'Brighter = higher value')
    colorbar_title = config.get('colorbar_title', 'Value')

    # REQUIRE data from LLM
    labels = config.get('labels')
    values = config.get('values')
    x_title = config.get('x_title')
    y_title = config.get('y_title')

    # If required data is missing, show clear error
    if not labels or not values or not x_title or not y_title:
        return create_placeholder_figure(
            "Missing Data for Matrix Heatmap",
            "LLM must provide: labels, values (2D array), x_title, y_title.<br><br>"
            "Example for Attention:<br>"
            "labels=['The', 'cat', 'sat'], x_title='Keys', y_title='Queries'<br>"
            "values=[[0.5, 0.3, 0.2], [0.1, 0.7, 0.2], [0.2, 0.2, 0.6]]"
        )

    # Convert values to numpy array
    try:
        data = np.array(values, dtype=float)
        if data.ndim != 2:
            raise ValueError("values must be 2D")
    except (TypeError, ValueError) as e:
        return create_placeholder_figure(
            "Invalid Data",
            f"'values' must be a 2D array of numbers.<br>Error: {e}"
        )

    # Parse the interactive params once, reporting bad input the same way as
    # bad matrix data rather than letting int()/float() raise.
    focus_row = config.get('focus_row')
    threshold = config.get('threshold')
    try:
        focus_idx = None if focus_row is None else int(focus_row) - 1  # 0-indexed
        threshold = None if threshold is None else float(threshold)
    except (TypeError, ValueError) as e:
        return create_placeholder_figure(
            "Invalid Data",
            f"'focus_row' must be an integer and 'threshold' a number.<br>Error: {e}"
        )
    has_focus = focus_idx is not None and 0 <= focus_idx < data.shape[0]

    notes = []

    # INTERACTIVE: Apply threshold filter (on the true values, before dimming)
    if threshold is not None:
        data = np.where(data >= threshold, data, 0)
        notes.append(f'Showing values ≥ {threshold:.2f}')

    # INTERACTIVE: Focus on specific row by dimming all the others
    if has_focus:
        dimming = np.ones_like(data) * 0.3
        dimming[focus_idx, :] = 1.0
        data = data * dimming
        # labels may be shorter than the matrix; fall back to the row number.
        focus_label = labels[focus_idx] if focus_idx < len(labels) else focus_row
        notes.insert(0, f'Focusing on row {focus_row}: "{focus_label}"')

    if notes:
        subtitle = ' | '.join(notes)

    hover_template = config.get('hover_template', '%{y} → %{x}<br>Value: %{z:.3f}<extra></extra>')

    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=data, x=labels[:data.shape[1]], y=labels[:data.shape[0]],
        colorscale='Viridis',
        hovertemplate=hover_template,
        colorbar=dict(title=dict(text=colorbar_title, side='right'), thickness=15),
    ))

    # Highlight focused row with border (cell i spans i-0.5 .. i+0.5)
    if has_focus:
        fig.add_shape(
            type='rect',
            x0=-0.5, x1=data.shape[1] - 0.5,
            y0=focus_idx - 0.5, y1=focus_idx + 0.5,
            line=dict(color='#ff6b6b', width=3),
        )

    # Highlight diagonal if square matrix (only if no focus was requested)
    if focus_row is None and data.shape[0] == data.shape[1]:
        for i in range(min(data.shape[0], len(labels))):
            if data[i, i] > 0.2:
                fig.add_annotation(
                    x=i, y=i,
                    text="●",
                    showarrow=False, font=dict(size=8, color='red')
                )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>{subtitle}</sub>"),
        xaxis=dict(title=x_title, tickangle=45),
        yaxis=dict(title=y_title, autorange='reversed'),
        **LAYOUT_DEFAULTS,
    )
    return fig
def distribution_plot(config: dict) -> go.Figure:
    """
    PURE RENDERER - Probability distribution visualization.

    REQUIRES from LLM:
    - categories: list of labels (e.g., ["Token 1", "Token 2"] or ["Dog", "Cat"])
    - values: list of probabilities (should sum to ~1.0), one per category

    OPTIONAL:
    - title, subtitle, x_title

    INTERACTIVE PARAMS:
    - temperature: Adjusts distribution sharpness (0.1=peaked, 2.0=uniform).
      Missing/None means 1.0 (no change); values <= 0 are ignored. Scaling
      takes log(values), divides by temperature and re-normalises with a
      softmax, so it requires non-negative values.

    Does NOT generate data internally - LLM must provide concept-specific data.
    Missing, mismatched or malformed input yields a placeholder figure with an
    error message instead of a misleading chart.
    """
    title = config.get('title', 'Probability Distribution')
    subtitle = config.get('subtitle', 'Output probabilities')
    x_title = config.get('x_title', 'Category')

    # REQUIRE data from LLM - no defaults!
    categories = config.get('categories')
    values = config.get('values')

    # If data is missing, show clear error instead of wrong defaults
    if not categories or not values:
        return create_placeholder_figure(
            "Missing Data for Distribution Plot",
            "LLM must provide 'categories' and 'values'.<br><br>"
            "Example: categories=['Token 1', 'Token 2'], values=[0.6, 0.4]"
        )

    # Ensure values is a list of floats
    try:
        values = [float(v) for v in values]
    except (TypeError, ValueError):
        return create_placeholder_figure(
            "Invalid Data",
            f"'values' must be a list of numbers, got: {type(values)}"
        )

    # zip() below would silently drop the unmatched tail, leaving a chart
    # whose bars no longer describe the data that was supplied.
    if len(categories) != len(values):
        return create_placeholder_figure(
            "Invalid Data",
            f"'categories' has {len(categories)} entries but 'values' has {len(values)}."
        )

    # INTERACTIVE: Apply temperature scaling (softmax with temperature)
    raw_temperature = config.get('temperature')
    try:
        temperature = 1.0 if raw_temperature is None else float(raw_temperature)
    except (TypeError, ValueError):
        return create_placeholder_figure(
            "Invalid Data",
            "'temperature' must be a number."
        )
    if temperature != 1.0 and temperature > 0:
        if min(values) < 0:
            # log() of a negative number is NaN and would blank every bar.
            return create_placeholder_figure(
                "Invalid Data",
                "Temperature scaling needs non-negative 'values'."
            )
        # Convert to logits (inverse softmax approximation), apply temperature, then softmax
        log_values = np.log(np.array(values) + 1e-10)  # epsilon avoids log(0)
        scaled = log_values / temperature
        exp_values = np.exp(scaled - np.max(scaled))  # Numerical stability
        values = (exp_values / exp_values.sum()).tolist()
        subtitle = f'Temperature = {temperature:.1f} ({"sharper" if temperature < 1 else "smoother"})'

    # Sort by probability for better visualization
    sorted_pairs = sorted(zip(categories, values), key=lambda x: x[1], reverse=True)
    categories = [p[0] for p in sorted_pairs]
    values = [p[1] for p in sorted_pairs]

    colors = [COLORS['gradient'][i % len(COLORS['gradient'])] for i in range(len(categories))]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=categories, y=values,
        marker=dict(color=colors, line=dict(color='white', width=1)),
        text=[f'{v:.1%}' for v in values],
        textposition='outside',
        textfont=dict(size=14, color='white'),
        hovertemplate='<b>%{x}</b><br>Probability: %{y:.2%}<extra></extra>',
    ))

    # Highlight winner (the lists are non-empty and sorted, so index 0 is it)
    fig.add_annotation(
        x=categories[0], y=values[0] + 0.05,
        text="👆 Highest",
        showarrow=False, font=dict(size=12, color='lime'),
    )

    # Leave 30% headroom for the outside labels; fall back to a unit axis
    # when every value is zero so the range is not collapsed to [0, 0].
    y_top = max(values) * 1.3 or 1

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>{subtitle}</sub>"),
        xaxis=dict(title=x_title),
        yaxis=dict(title='Probability', tickformat='.0%', range=[0, y_top]),
        **LAYOUT_DEFAULTS,
    )
    return fig
def decision_boundary(config: dict) -> go.Figure:
    """
    Two-class scatter plot over shaded decision regions.

    The data is synthetic and the "model" is the rule that generated the
    labels, so the shaded regions match the classes by construction.

    Config keys (all optional):
        title: Chart title.
        model_type: 'circular' for an inner/outer ring split at radius 2;
            any other value gives the linear split x + y > 0 (default 'linear').
        n_points: Number of points (default 200, minimum 2). Odd counts are
            supported; the extra point goes to the outer ring in circular mode.
        seed: Seed passed to ``np.random.seed`` (default 42); reseeds NumPy's
            global generator.

    Returns:
        A figure with a filled contour for the regions, one scatter trace per
        class, and an arrow pointing at a point on the boundary.
    """
    title = config.get('title', 'Decision Boundary')
    model_type = config.get('model_type', 'linear')
    # At least two points so min()/max() below always have data to work on.
    n_points = max(2, int(config.get('n_points') or 200))
    seed = config.get('seed', 42)
    np.random.seed(seed)

    is_circular = model_type == 'circular'

    # Generate data
    if is_circular:
        inner_count = n_points // 2
        r1 = np.random.randn(inner_count) * 0.5 + 1
        # Give the outer ring the remainder so r and theta have equal length
        # even when n_points is odd.
        r2 = np.random.randn(n_points - inner_count) * 0.5 + 3
        theta = np.random.rand(n_points) * 2 * np.pi
        r = np.concatenate([r1, r2])
        X = np.column_stack([r * np.cos(theta), r * np.sin(theta)])
        y = (r < 2).astype(int)
    else:  # linear
        X = np.random.randn(n_points, 2) * 2
        y = (X[:, 0] + X[:, 1] > 0).astype(int)

    fig = go.Figure()

    # Decision regions, evaluated on a grid padded one unit beyond the data
    grid_x = np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100)
    grid_y = np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100)
    xx, yy = np.meshgrid(grid_x, grid_y)
    if is_circular:
        Z = (np.sqrt(xx**2 + yy**2) < 2).astype(float)
    else:
        Z = (xx + yy > 0).astype(float)

    fig.add_trace(go.Contour(
        x=grid_x,
        y=grid_y,
        z=Z,
        colorscale=[[0, 'rgba(250,112,154,0.3)'], [1, 'rgba(79,172,254,0.3)']],
        showscale=False, contours=dict(showlines=True, coloring='fill'),
        hoverinfo='skip',
    ))

    # Data points
    for label, color, name in [(0, COLORS['warning'], 'Class A'), (1, COLORS['info'], 'Class B')]:
        mask = y == label
        fig.add_trace(go.Scatter(
            x=X[mask, 0], y=X[mask, 1],
            mode='markers',
            marker=dict(size=10, color=color, line=dict(color='white', width=1)),
            name=name,
            hovertemplate=f'<b>{name}</b><br>(%{{x:.1f}}, %{{y:.1f}})<extra></extra>',
        ))

    # Boundary annotation: the arrow must touch the boundary itself. The line
    # x + y = 0 passes through the origin; the radius-2 circle does not, so
    # use the point on it at 45 degrees.
    if is_circular:
        anchor_x = anchor_y = 2 * np.cos(np.pi / 4)
    else:
        anchor_x = anchor_y = 0
    fig.add_annotation(
        x=anchor_x, y=anchor_y,
        text="← Decision Boundary<br>Model classifies differently<br>on each side",
        showarrow=True, arrowhead=2,
        font=dict(size=11, color='white'),
        ax=100, ay=-50
    )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Shaded regions show model's classification</sub>"),
        xaxis=dict(title='Feature 1'),
        yaxis=dict(title='Feature 2'),
        legend=dict(orientation='h', y=1.1),
        **LAYOUT_DEFAULTS,
    )
    return fig
def line_progression(config: dict) -> go.Figure:
    """
    Synthetic training/validation loss curves with an overfitting callout.

    Both curves are exponential decays plus Gaussian noise; they are
    generated here, not read from a real training run.

    Config keys (all optional):
        title: Chart title.
        epochs: Number of epochs to draw (default 50, minimum 1).
        seed: Seed passed to ``np.random.seed`` (default 42); reseeds NumPy's
            global generator.

    Returns:
        A figure with the two loss traces. If the final validation loss is
        higher than the one at the halfway index, the second half is shaded
        and labelled as an overfitting zone.
    """
    title = config.get('title', 'Training Progress')
    # Clamp so the arrays below are never empty (val_loss[-1] needs an entry).
    epochs = max(1, int(config.get('epochs') or 50))
    seed = config.get('seed', 42)
    np.random.seed(seed)

    x = np.arange(1, epochs + 1)

    # Generate realistic curves
    train_loss = 2.0 * np.exp(-0.08 * x) + 0.1 + np.random.randn(epochs) * 0.03
    val_loss = 2.2 * np.exp(-0.06 * x) + 0.15 + np.random.randn(epochs) * 0.05

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x, y=train_loss, mode='lines',
        name='Training Loss', line=dict(color=COLORS['primary'], width=3),
        hovertemplate='Epoch %{x}<br>Train Loss: %{y:.4f}<extra></extra>',
    ))
    fig.add_trace(go.Scatter(
        x=x, y=val_loss, mode='lines',
        name='Validation Loss', line=dict(color=COLORS['warning'], width=3, dash='dash'),
        hovertemplate='Epoch %{x}<br>Val Loss: %{y:.4f}<extra></extra>',
    ))

    # Overfit annotation if applicable
    overfit_start = epochs // 2
    if val_loss[-1] > val_loss[overfit_start]:
        fig.add_vrect(x0=overfit_start, x1=epochs, fillcolor='red', opacity=0.1)
        fig.add_annotation(
            # Keep the label inside the shaded zone on short runs, where
            # overfit_start + 10 would land past the last epoch.
            x=min(overfit_start + 10, epochs), y=val_loss.max(),
            text="⚠️ Overfitting zone<br>Val loss increasing",
            showarrow=False, font=dict(color='red', size=11),
        )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b><br><sub>Lower is better - watch for val loss going up</sub>"),
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss'),
        legend=dict(orientation='h', y=1.1),
        hovermode='x unified',
        **LAYOUT_DEFAULTS,
    )
    return fig
def comparison_bars(config: dict) -> go.Figure:
    """
    Grouped bar chart comparing models across metrics.

    Config keys (all optional):
        title: Chart title.
        categories: Metric names on the x-axis
            (default ['Accuracy', 'Precision', 'Recall', 'F1']).
        groups: List of dicts, one per model, each with ``values`` (numbers,
            one per category) and optionally ``name`` and ``color``. Entries
            that are not dicts or whose values are not numeric are skipped.
            When absent, or when nothing usable remains, two built-in example
            models ("Model A" / "Model B") are shown.
        seed: Passed to ``np.random.seed`` (default 42). Nothing random is
            drawn here; the call only resets NumPy's global generator.

    Each group's values are aligned with ``categories`` by truncating both to
    the shorter length, so bars are never paired with the wrong label.

    Returns:
        A figure with one bar trace per group.
    """
    title = config.get('title', 'Model Comparison')
    categories = config.get('categories', ['Accuracy', 'Precision', 'Recall', 'F1'])
    seed = config.get('seed', 42)
    np.random.seed(seed)

    default_groups = [
        {'name': 'Model A', 'values': [0.92, 0.89, 0.94, 0.91], 'color': COLORS['primary']},
        {'name': 'Model B', 'values': [0.88, 0.91, 0.85, 0.88], 'color': COLORS['warning']},
    ]

    # Normalise caller-supplied groups; fall back to the examples if none
    # of them is usable.
    palette = COLORS['gradient']
    groups = []
    for i, raw in enumerate(config.get('groups') or []):
        if not isinstance(raw, dict):
            continue
        try:
            group_values = [float(v) for v in raw.get('values') or []]
        except (TypeError, ValueError):
            continue
        if not group_values:
            continue
        groups.append({
            'name': raw.get('name', f'Model {i+1}'),
            'values': group_values,
            'color': raw.get('color', palette[i % len(palette)]),
        })
    if not groups:
        groups = default_groups

    fig = go.Figure()
    for group in groups:
        shown = min(len(categories), len(group['values']))
        bar_values = group['values'][:shown]
        fig.add_trace(go.Bar(
            x=categories[:shown], y=bar_values,
            name=group['name'],
            marker=dict(color=group['color'], line=dict(color='white', width=1)),
            text=[f'{v:.0%}' for v in bar_values],
            textposition='outside',
            hovertemplate=f"<b>{group['name']}</b><br>%{{x}}: %{{y:.1%}}<extra></extra>",
        ))

    fig.update_layout(
        title=dict(text=f"<b>{title}</b>"),
        barmode='group',
        yaxis=dict(tickformat='.0%', range=[0, 1.15]),
        legend=dict(orientation='h', y=1.1),
        **LAYOUT_DEFAULTS,
    )
    return fig
# Valid Plotly trace types for validation
VALID_TRACE_TYPES = {
    'scatter', 'scatter3d', 'scattergl', 'scatterpolar', 'scattergeo',
    'bar', 'histogram', 'histogram2d', 'box', 'violin',
    'heatmap', 'contour', 'surface', 'mesh3d',
    'pie', 'sunburst', 'treemap', 'sankey', 'funnel',
    'indicator', 'table', 'carpet', 'cone', 'streamtube',
    'isosurface', 'volume', 'image', 'candlestick', 'ohlc'
}


def validate_trace(trace: dict) -> bool:
    """
    Validate a Plotly trace dict has valid structure.

    Returns True only when *trace* is a dict whose 'type' is a string listed
    in VALID_TRACE_TYPES. A missing 'type' is treated as 'scatter'. Any other
    input, including a non-string 'type', returns False instead of raising.
    """
    if not isinstance(trace, dict):
        return False
    trace_type = trace.get('type', 'scatter')
    # A list/dict 'type' from malformed JSON is unhashable and would raise
    # TypeError on the set membership test, so reject non-strings first.
    if not isinstance(trace_type, str):
        return False
    return trace_type in VALID_TRACE_TYPES
def custom_plotly(config: dict) -> go.Figure:
    """
    Render arbitrary Plotly visualization from JSON spec.

    This enables the LLM to generate ANY visualization for ANY ML concept
    by providing the full Plotly JSON specification.

    Config should contain:
    - 'data': list of trace dicts (required); a single trace dict is also
      accepted and treated as a one-element list
    - 'layout': layout dict (optional, merged with defaults)
    - 'title': override title (optional); when absent, a title supplied via
      'layout' or the template is kept, otherwise 'Visualization' is used
    - 'template': name of a pre-defined template to use (optional)

    Returns the built figure. Falls back to scatter_cluster when there is no
    usable trace or the figure cannot be constructed. Layout entries that
    Plotly rejects are skipped with a printed message instead of raising.

    Security: This does NOT execute code - it only parses JSON into Plotly objects.
    """
    title = config.get('title', 'Visualization')
    data = config.get('data')
    layout = config.get('layout')
    template_name = config.get('template')

    # Normalise shapes up front so the template merge below cannot fail on
    # None / wrong-typed values coming from generated JSON.
    if isinstance(data, dict):
        data = [data]
    elif not isinstance(data, list):
        data = []
    if not isinstance(layout, dict):
        layout = {}

    # Load template if specified
    if template_name:
        try:
            from templates import get_template
            template = get_template(template_name)
            if template:
                # Template traces come first, caller traces after them
                data = list(template.get('base_data') or []) + data
                # Caller layout wins over template layout on key clashes
                layout = {**(template.get('layout') or {}), **layout}
                # Add template annotations
                if 'annotations' in template:
                    layout['annotations'] = (
                        list(layout.get('annotations') or [])
                        + list(template['annotations'] or [])
                    )
        except ImportError:
            print(f"Templates module not found, skipping template: {template_name}")

    # Validate data structure
    if not data:
        print("custom_plotly: No valid data provided, falling back to default")
        return scatter_cluster({'title': title, 'n_clusters': 3})

    # Filter and validate traces
    valid_traces = []
    for i, trace in enumerate(data):
        if validate_trace(trace):
            valid_traces.append(trace)
        else:
            print(f"custom_plotly: Skipping invalid trace at index {i}: {type(trace)}")
    if not valid_traces:
        print("custom_plotly: No valid traces found, falling back to default")
        return scatter_cluster({'title': title, 'n_clusters': 3})

    # Limit trace count to prevent memory issues
    MAX_TRACES = 50
    if len(valid_traces) > MAX_TRACES:
        print(f"custom_plotly: Limiting traces from {len(valid_traces)} to {MAX_TRACES}")
        valid_traces = valid_traces[:MAX_TRACES]

    # Create figure from validated data
    try:
        fig = go.Figure(data=valid_traces)
    except Exception as e:
        print(f"custom_plotly: Error creating figure: {e}")
        return scatter_cluster({'title': title, 'n_clusters': 3})

    # Merge layout with defaults
    merged_layout = {**LAYOUT_DEFAULTS}

    # Apply custom layout (safely): only allow-listed keys are copied over.
    # NOTE: 'zaxis' is not a top-level Plotly layout property (3D axes live
    # under 'scene'); if supplied, Plotly rejects it and the per-key fallback
    # below skips it.
    safe_layout_keys = [
        'title', 'xaxis', 'yaxis', 'zaxis', 'showlegend', 'legend',
        'annotations', 'shapes', 'images', 'height', 'width',
        'scene', 'geo', 'mapbox', 'polar', 'ternary',
        'coloraxis', 'hovermode', 'dragmode', 'barmode', 'bargap'
    ]
    for key in safe_layout_keys:
        if key in layout:
            merged_layout[key] = layout[key]

    # Set title with styling. An explicit config title always overrides; the
    # built-in default only applies when the layout did not bring its own.
    if title and ('title' in config or 'title' not in merged_layout):
        merged_layout['title'] = dict(
            text=f"<b>{title}</b>",
            font=dict(size=18)
        )

    try:
        fig.update_layout(**merged_layout)
    except Exception as e:
        # One bad entry should not discard the whole figure: re-apply the
        # layout key by key and drop only the entries Plotly refuses.
        print(f"custom_plotly: Error applying layout, retrying per key: {e}")
        for key, value in merged_layout.items():
            try:
                fig.update_layout(**{key: value})
            except Exception as key_error:
                print(f"custom_plotly: Skipping invalid layout key '{key}': {key_error}")
    return fig
# Registry: maps a component-type name to the function that builds it.
# render_component looks the name up here and calls the function with the
# caller's config dict.
COMPONENTS = {
    'scatter_cluster': scatter_cluster,
    'cluster_distribution': cluster_distribution,
    'gradient_descent_3d': gradient_descent_3d,
    'gradient_descent_2d': gradient_descent_2d,
    'loss_curve': loss_curve,
    'flow_diagram': flow_diagram,
    'matrix_heatmap': matrix_heatmap,
    'distribution_plot': distribution_plot,
    'decision_boundary': decision_boundary,
    'line_progression': line_progression,
    'comparison_bars': comparison_bars,
    'custom_plotly': custom_plotly,
}
def render_component(component_type: str, config: dict) -> go.Figure:
    """
    Build the figure for *component_type* by calling its registered builder.

    The builder is looked up in COMPONENTS and called with *config*. A name
    that is not registered is reported on stdout and rendered with
    scatter_cluster instead.
    """
    if component_type in COMPONENTS:
        builder = COMPONENTS[component_type]
        return builder(config)

    # Fallback to scatter_cluster for unknown types
    print(f"Unknown component {component_type}, using scatter_cluster")
    return scatter_cluster(config)