Spaces:
Sleeping
Sleeping
"""Template library for complex ML architecture visualizations.

Each template is a base Plotly JSON spec for a complex ML concept
(attention, LSTM gates, VAE, GAN, CNN) that the LLM can reference and
customize.
"""

# Shared hex palette used by the templates below.
COLORS = dict(
    # General-purpose gradient / status colours
    primary='#667eea',
    secondary='#764ba2',
    accent='#f093fb',
    success='#43e97b',
    warning='#fa709a',
    info='#4facfe',
    # Transformer attention roles (Q / K / V)
    query='#ff6b6b',
    key='#4ecdc4',
    value='#45b7d1',
)
def _node(x, y, text, *, name, hover, size, color, symbol, font_size,
          font_color='white'):
    """Build a labelled ``markers+text`` scatter trace used as a diagram node.

    ``x``, ``y`` and ``text`` are parallel lists with one entry per marker.
    ``color`` is either a single colour or a per-marker list. The label is
    drawn centred on the marker, and ``hover`` becomes the trace's
    ``hovertemplate``. Requiring ``hover`` keeps every node's tooltip explicit.
    """
    return {
        'type': 'scatter',
        'x': x, 'y': y,
        'mode': 'markers+text',
        'marker': {'size': size, 'color': color, 'symbol': symbol},
        'text': text,
        'textposition': 'middle center',
        'textfont': {'size': font_size, 'color': font_color},
        'name': name,
        'hovertemplate': hover,
    }


def _link(x0, y0, x1, y1, color, width=3):
    """Build a straight Plotly ``line`` shape (no arrowhead) joining two nodes."""
    return {'type': 'line', 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1,
            'line': {'color': color, 'width': width}}


# Registry of diagram templates, keyed by template name.
# Each template has three parts:
#   'base_data'   - list of Plotly trace dicts (diagram nodes / lines)
#   'layout'      - Plotly layout dict (hidden axes, height, connector shapes)
#   'annotations' - list of Plotly annotation dicts (titles, captions)
# NOTE(review): 'annotations' sits beside 'layout' rather than inside it, so
# Plotly will not pick it up from {'data': ..., 'layout': ...} on its own;
# presumably the consumer merges it into layout['annotations'] — confirm.
TEMPLATES = {
    # Transformer Attention Q/K/V Architecture
    'transformer_attention': {
        'base_data': [
            _node([0.15], [0.7], ['Q'], name='Query', size=60,
                  color=COLORS['query'], symbol='square', font_size=24,
                  hover='<b>Query (Q)</b><br>Represents "what am I looking for?"<extra></extra>'),
            _node([0.5], [0.7], ['K'], name='Key', size=60,
                  color=COLORS['key'], symbol='square', font_size=24,
                  hover='<b>Key (K)</b><br>Represents "what do I contain?"<extra></extra>'),
            _node([0.85], [0.7], ['V'], name='Value', size=60,
                  color=COLORS['value'], symbol='square', font_size=24,
                  hover='<b>Value (V)</b><br>The actual information to retrieve<extra></extra>'),
            # Attention scores (Q·Kᵀ)
            _node([0.325], [0.4], ['Q·Kᵀ'], name='Attention Scores', size=50,
                  color=COLORS['accent'], symbol='diamond', font_size=16,
                  hover='<b>Attention Scores</b><br>Dot product of Q and K<br>Shows how much each token attends to others<extra></extra>'),
            _node([0.325], [0.2], ['softmax'], name='Softmax', size=45,
                  color=COLORS['warning'], symbol='circle', font_size=12,
                  hover='<b>Softmax</b><br>Normalizes scores to probabilities (sum to 1)<extra></extra>'),
            _node([0.675], [0.2], ['Output'], name='Output', size=50,
                  color=COLORS['success'], symbol='square', font_size=14,
                  hover='<b>Attention Output</b><br>Weighted sum of Values<br>Based on attention weights<extra></extra>'),
        ],
        'layout': {
            'xaxis': {'visible': False, 'range': [-0.1, 1.1]},
            'yaxis': {'visible': False, 'range': [-0.1, 1.0]},
            'showlegend': False,
            'height': 500,
            'shapes': [
                _link(0.15, 0.62, 0.28, 0.45, COLORS['query']),     # Q -> Q·Kᵀ
                _link(0.5, 0.62, 0.37, 0.45, COLORS['key']),        # K -> Q·Kᵀ
                _link(0.325, 0.33, 0.325, 0.25, COLORS['accent']),  # Q·Kᵀ -> softmax
                _link(0.38, 0.2, 0.6, 0.2, COLORS['warning']),      # softmax -> output
                _link(0.85, 0.62, 0.72, 0.25, COLORS['value']),     # V -> output
            ],
        },
        'annotations': [
            {'x': 0.15, 'y': 0.85, 'text': '<b>Query</b><br>What am I looking for?',
             'showarrow': False, 'font': {'size': 11, 'color': COLORS['query']}},
            {'x': 0.5, 'y': 0.85, 'text': '<b>Key</b><br>What do I contain?',
             'showarrow': False, 'font': {'size': 11, 'color': COLORS['key']}},
            {'x': 0.85, 'y': 0.85, 'text': '<b>Value</b><br>Actual content',
             'showarrow': False, 'font': {'size': 11, 'color': COLORS['value']}},
            {'x': 0.5, 'y': 0.02, 'text': 'Attention(Q,K,V) = softmax(QKᵀ/√d)V',
             'showarrow': False, 'font': {'size': 14, 'color': 'white'},
             'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 8},
        ],
    },

    # LSTM Gate Architecture
    'lstm_gates': {
        'base_data': [
            _node([0.2], [0.7], ['f'], name='Forget Gate', size=55,
                  color='#e74c3c', symbol='square', font_size=20,
                  hover='<b>Forget Gate (f)</b><br>Decides what to forget from cell state<br>f = σ(W·[h, x] + b)<extra></extra>'),
            _node([0.4], [0.7], ['i'], name='Input Gate', size=55,
                  color='#2ecc71', symbol='square', font_size=20,
                  hover='<b>Input Gate (i)</b><br>Decides what new info to store<br>i = σ(W·[h, x] + b)<extra></extra>'),
            _node([0.6], [0.7], ['C̃'], name='Candidate', size=55,
                  color='#9b59b6', symbol='square', font_size=20,
                  hover='<b>Candidate Cell State (C̃)</b><br>New candidate values<br>C̃ = tanh(W·[h, x] + b)<extra></extra>'),
            _node([0.8], [0.7], ['o'], name='Output Gate', size=55,
                  color='#3498db', symbol='square', font_size=20,
                  hover='<b>Output Gate (o)</b><br>Decides what to output<br>o = σ(W·[h, x] + b)<extra></extra>'),
            # Cell state: a thick horizontal line rather than a node.
            {
                'type': 'scatter',
                'x': [0.1, 0.9], 'y': [0.4, 0.4],
                'mode': 'lines',
                'line': {'color': '#f39c12', 'width': 8},
                'name': 'Cell State',
                'hovertemplate': '<b>Cell State (C)</b><br>The "memory" that flows through time<extra></extra>',
            },
            _node([0.8], [0.15], ['hₜ'], name='Hidden State', size=45,
                  color='#1abc9c', symbol='circle', font_size=16,
                  hover='<b>Hidden State (hₜ)</b><br>Output at time t<br>hₜ = o * tanh(Cₜ)<extra></extra>'),
        ],
        'layout': {
            'xaxis': {'visible': False, 'range': [0, 1]},
            'yaxis': {'visible': False, 'range': [0, 1]},
            'showlegend': False,
            'height': 500,
        },
        'annotations': [
            {'x': 0.2, 'y': 0.88, 'text': '<b>Forget</b>', 'showarrow': False,
             'font': {'size': 12, 'color': '#e74c3c'}},
            {'x': 0.4, 'y': 0.88, 'text': '<b>Input</b>', 'showarrow': False,
             'font': {'size': 12, 'color': '#2ecc71'}},
            {'x': 0.6, 'y': 0.88, 'text': '<b>Candidate</b>', 'showarrow': False,
             'font': {'size': 12, 'color': '#9b59b6'}},
            {'x': 0.8, 'y': 0.88, 'text': '<b>Output</b>', 'showarrow': False,
             'font': {'size': 12, 'color': '#3498db'}},
            # Sits just above the cell-state line (y=0.4). The line is drawn in
            # the same colour, so centring the label on it hid the text.
            {'x': 0.5, 'y': 0.47, 'text': '← Cell State (Long-term memory) →',
             'showarrow': False, 'font': {'size': 12, 'color': '#f39c12'}},
        ],
    },

    # Variational Autoencoder
    'vae_architecture': {
        'base_data': [
            _node([0.15], [0.5], ['Encoder'], name='Encoder', size=80,
                  color=COLORS['primary'], symbol='square', font_size=14,
                  hover='<b>Encoder</b><br>Compresses input to latent distribution<br>Outputs μ (mean) and σ (std)<extra></extra>'),
            # Latent parameters μ and σ share one trace with per-marker colours.
            _node([0.45, 0.45], [0.65, 0.35], ['μ', 'σ'], name='Latent Params',
                  size=50, color=[COLORS['info'], COLORS['warning']],
                  symbol='circle', font_size=20,
                  hover='%{text}: Latent parameter<extra></extra>'),
            _node([0.6], [0.5], ['z'], name='Latent Code', size=55,
                  color=COLORS['accent'], symbol='diamond', font_size=20,
                  hover='<b>Latent Code (z)</b><br>z = μ + σ * ε<br>ε ~ N(0,1) (reparameterization trick)<extra></extra>'),
            _node([0.85], [0.5], ['Decoder'], name='Decoder', size=80,
                  color=COLORS['success'], symbol='square', font_size=14,
                  hover='<b>Decoder</b><br>Reconstructs input from latent code<extra></extra>'),
        ],
        'layout': {
            'xaxis': {'visible': False, 'range': [0, 1]},
            'yaxis': {'visible': False, 'range': [0, 1]},
            'showlegend': False,
            'height': 450,
            'shapes': [
                _link(0.02, 0.5, 0.08, 0.5, 'white', 3),              # input -> encoder
                _link(0.22, 0.55, 0.4, 0.65, COLORS['primary'], 2),   # encoder -> μ
                _link(0.22, 0.45, 0.4, 0.35, COLORS['primary'], 2),   # encoder -> σ
                _link(0.5, 0.62, 0.55, 0.52, COLORS['info'], 2),      # μ -> z
                _link(0.5, 0.38, 0.55, 0.48, COLORS['warning'], 2),   # σ -> z
                _link(0.65, 0.5, 0.78, 0.5, COLORS['accent'], 3),     # z -> decoder
                _link(0.92, 0.5, 0.98, 0.5, 'white', 3),              # decoder -> output
            ],
        },
        'annotations': [
            {'x': 0.02, 'y': 0.58, 'text': 'Input x', 'showarrow': False,
             'font': {'size': 11, 'color': 'white'}},
            {'x': 0.98, 'y': 0.58, 'text': 'Output x̂', 'showarrow': False,
             'font': {'size': 11, 'color': 'white'}},
            {'x': 0.5, 'y': 0.15, 'text': 'Loss = Reconstruction + KL Divergence',
             'showarrow': False, 'font': {'size': 12, 'color': 'white'},
             'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 6},
        ],
    },

    # GAN Architecture
    'gan_architecture': {
        'base_data': [
            _node([0.1], [0.7], ['z'], name='Noise', size=45,
                  color='#95a5a6', symbol='circle', font_size=18,
                  hover='<b>Random Noise (z)</b><br>Sampled from normal distribution<extra></extra>'),
            _node([0.3], [0.7], ['G'], name='Generator', size=70,
                  color=COLORS['success'], symbol='square', font_size=24,
                  hover='<b>Generator (G)</b><br>Creates fake data from noise<br>Tries to fool the Discriminator<extra></extra>'),
            _node([0.5], [0.7], ['Fake'], name='Fake Data', size=50,
                  color=COLORS['warning'], symbol='diamond', font_size=12,
                  hover='<b>Generated (Fake) Data</b><br>Output from Generator<extra></extra>'),
            _node([0.5], [0.3], ['Real'], name='Real Data', size=50,
                  color=COLORS['info'], symbol='diamond', font_size=12,
                  hover='<b>Real Data</b><br>From training dataset<extra></extra>'),
            _node([0.75], [0.5], ['D'], name='Discriminator', size=70,
                  color=COLORS['primary'], symbol='square', font_size=24,
                  hover='<b>Discriminator (D)</b><br>Classifies real vs fake<br>Tries to catch the Generator<extra></extra>'),
            # Light marker, so the label uses dark text instead of white.
            _node([0.92], [0.5], ['0/1'], name='Output', size=40,
                  color='#ecf0f1', symbol='circle', font_size=12,
                  font_color='#333',
                  hover='<b>Classification</b><br>1 = Real, 0 = Fake<extra></extra>'),
        ],
        'layout': {
            'xaxis': {'visible': False, 'range': [0, 1]},
            'yaxis': {'visible': False, 'range': [0, 1]},
            'showlegend': False,
            'height': 450,
            'shapes': [
                _link(0.15, 0.7, 0.23, 0.7, '#95a5a6'),             # noise -> G
                _link(0.37, 0.7, 0.45, 0.7, COLORS['success']),     # G -> fake
                _link(0.55, 0.65, 0.68, 0.55, COLORS['warning']),   # fake -> D
                _link(0.55, 0.35, 0.68, 0.45, COLORS['info']),      # real -> D
                _link(0.82, 0.5, 0.88, 0.5, COLORS['primary']),     # D -> output
            ],
        },
        'annotations': [
            {'x': 0.3, 'y': 0.88, 'text': '<b>Generator</b><br>Creates fakes',
             'showarrow': False, 'font': {'size': 11, 'color': COLORS['success']}},
            {'x': 0.75, 'y': 0.88, 'text': '<b>Discriminator</b><br>Detects fakes',
             'showarrow': False, 'font': {'size': 11, 'color': COLORS['primary']}},
            {'x': 0.5, 'y': 0.08, 'text': 'Adversarial Training: G tries to fool D, D tries to catch G',
             'showarrow': False, 'font': {'size': 12, 'color': 'white'},
             'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 6},
        ],
    },

    # CNN Architecture
    'cnn_architecture': {
        'base_data': [
            _node([0.08], [0.5], ['Input'], name='Input', size=60,
                  color='#3498db', symbol='square', font_size=11,
                  hover='<b>Input Image</b><br>Raw pixel values<extra></extra>'),
            _node([0.25], [0.5], ['Conv1'], name='Conv1', size=55,
                  color='#e74c3c', symbol='square', font_size=10,
                  hover='<b>Convolution Layer 1</b><br>Detects edges and simple patterns<extra></extra>'),
            _node([0.38], [0.5], ['Pool'], name='Pool1', size=45,
                  color='#9b59b6', symbol='circle', font_size=9,
                  hover='<b>Pooling Layer</b><br>Reduces spatial size<br>Keeps important features<extra></extra>'),
            _node([0.52], [0.5], ['Conv2'], name='Conv2', size=50,
                  color='#e74c3c', symbol='square', font_size=10,
                  hover='<b>Convolution Layer 2</b><br>Detects complex patterns<extra></extra>'),
            # Previously the only node without a hovertemplate, so it fell back
            # to Plotly's raw default tooltip.
            _node([0.65], [0.5], ['Pool'], name='Pool2', size=40,
                  color='#9b59b6', symbol='circle', font_size=9,
                  hover='<b>Pooling Layer 2</b><br>Further reduces spatial size<br>Keeps important features<extra></extra>'),
            _node([0.78], [0.5], ['Flat'], name='Flatten', size=35,
                  color='#f39c12', symbol='diamond', font_size=9,
                  hover='<b>Flatten</b><br>Convert 2D features to 1D vector<extra></extra>'),
            _node([0.9], [0.5], ['FC'], name='Fully Connected', size=45,
                  color='#2ecc71', symbol='square', font_size=11,
                  hover='<b>Fully Connected</b><br>Final classification layer<extra></extra>'),
        ],
        'layout': {
            'xaxis': {'visible': False, 'range': [0, 1]},
            'yaxis': {'visible': False, 'range': [0, 1]},
            'showlegend': False,
            'height': 400,
        },
        'annotations': [
            {'x': 0.5, 'y': 0.85, 'text': '<b>CNN: Convolutional Neural Network</b>',
             'showarrow': False, 'font': {'size': 14, 'color': 'white'}},
            {'x': 0.3, 'y': 0.2, 'text': 'Feature Extraction',
             'showarrow': False, 'font': {'size': 11, 'color': '#e74c3c'}},
            {'x': 0.85, 'y': 0.2, 'text': 'Classification',
             'showarrow': False, 'font': {'size': 11, 'color': '#2ecc71'}},
        ],
    },
}
def get_template(name: str) -> "dict | None":
    """Return an independent copy of the template registered under *name*.

    A deep copy is returned so callers can customize the spec (move nodes,
    edit labels, append shapes) without mutating the shared entry in
    ``TEMPLATES``. Such edits would otherwise leak into every later call.

    Args:
        name: Template key; see ``list_templates()`` for valid names.

    Returns:
        A deep copy of the template dict, or ``None`` if *name* is unknown.
    """
    import copy  # function-scope import: this module has no import block

    template = TEMPLATES.get(name)
    if template is None:
        return None
    return copy.deepcopy(template)
def list_templates() -> list:
    """Return every registered template name, in the order they are defined."""
    # Iterating a dict yields its keys in insertion order.
    return list(TEMPLATES)