AI-Learning-Playground / templates.py
adi-123's picture
Upload 5 files
888531f verified
"""
Template Library for Complex ML Architecture Visualizations
These templates provide base Plotly JSON specs for complex ML concepts
that the LLM can reference and customize.
"""
COLORS = {
'primary': '#667eea',
'secondary': '#764ba2',
'accent': '#f093fb',
'success': '#43e97b',
'warning': '#fa709a',
'info': '#4facfe',
'query': '#ff6b6b',
'key': '#4ecdc4',
'value': '#45b7d1',
}
TEMPLATES = {
# Transformer Attention Q/K/V Architecture
'transformer_attention': {
'base_data': [
# Query box
{
'type': 'scatter',
'x': [0.15], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 60, 'color': COLORS['query'], 'symbol': 'square'},
'text': ['Q'],
'textposition': 'middle center',
'textfont': {'size': 24, 'color': 'white'},
'name': 'Query',
'hovertemplate': '<b>Query (Q)</b><br>Represents "what am I looking for?"<extra></extra>'
},
# Key box
{
'type': 'scatter',
'x': [0.5], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 60, 'color': COLORS['key'], 'symbol': 'square'},
'text': ['K'],
'textposition': 'middle center',
'textfont': {'size': 24, 'color': 'white'},
'name': 'Key',
'hovertemplate': '<b>Key (K)</b><br>Represents "what do I contain?"<extra></extra>'
},
# Value box
{
'type': 'scatter',
'x': [0.85], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 60, 'color': COLORS['value'], 'symbol': 'square'},
'text': ['V'],
'textposition': 'middle center',
'textfont': {'size': 24, 'color': 'white'},
'name': 'Value',
'hovertemplate': '<b>Value (V)</b><br>The actual information to retrieve<extra></extra>'
},
# Attention scores (Q·K)
{
'type': 'scatter',
'x': [0.325], 'y': [0.4],
'mode': 'markers+text',
'marker': {'size': 50, 'color': COLORS['accent'], 'symbol': 'diamond'},
'text': ['Q·Kᵀ'],
'textposition': 'middle center',
'textfont': {'size': 16, 'color': 'white'},
'name': 'Attention Scores',
'hovertemplate': '<b>Attention Scores</b><br>Dot product of Q and K<br>Shows how much each token attends to others<extra></extra>'
},
# Softmax
{
'type': 'scatter',
'x': [0.325], 'y': [0.2],
'mode': 'markers+text',
'marker': {'size': 45, 'color': COLORS['warning'], 'symbol': 'circle'},
'text': ['softmax'],
'textposition': 'middle center',
'textfont': {'size': 12, 'color': 'white'},
'name': 'Softmax',
'hovertemplate': '<b>Softmax</b><br>Normalizes scores to probabilities (sum to 1)<extra></extra>'
},
# Output
{
'type': 'scatter',
'x': [0.675], 'y': [0.2],
'mode': 'markers+text',
'marker': {'size': 50, 'color': COLORS['success'], 'symbol': 'square'},
'text': ['Output'],
'textposition': 'middle center',
'textfont': {'size': 14, 'color': 'white'},
'name': 'Output',
'hovertemplate': '<b>Attention Output</b><br>Weighted sum of Values<br>Based on attention weights<extra></extra>'
},
],
'layout': {
'xaxis': {'visible': False, 'range': [-0.1, 1.1]},
'yaxis': {'visible': False, 'range': [-0.1, 1.0]},
'showlegend': False,
'height': 500,
'shapes': [
# Arrow Q to Q·K
{'type': 'line', 'x0': 0.15, 'y0': 0.62, 'x1': 0.28, 'y1': 0.45,
'line': {'color': COLORS['query'], 'width': 3}},
# Arrow K to Q·K
{'type': 'line', 'x0': 0.5, 'y0': 0.62, 'x1': 0.37, 'y1': 0.45,
'line': {'color': COLORS['key'], 'width': 3}},
# Arrow Q·K to softmax
{'type': 'line', 'x0': 0.325, 'y0': 0.33, 'x1': 0.325, 'y1': 0.25,
'line': {'color': COLORS['accent'], 'width': 3}},
# Arrow softmax to output
{'type': 'line', 'x0': 0.38, 'y0': 0.2, 'x1': 0.6, 'y1': 0.2,
'line': {'color': COLORS['warning'], 'width': 3}},
# Arrow V to output
{'type': 'line', 'x0': 0.85, 'y0': 0.62, 'x1': 0.72, 'y1': 0.25,
'line': {'color': COLORS['value'], 'width': 3}},
],
},
'annotations': [
{'x': 0.15, 'y': 0.85, 'text': '<b>Query</b><br>What am I looking for?',
'showarrow': False, 'font': {'size': 11, 'color': COLORS['query']}},
{'x': 0.5, 'y': 0.85, 'text': '<b>Key</b><br>What do I contain?',
'showarrow': False, 'font': {'size': 11, 'color': COLORS['key']}},
{'x': 0.85, 'y': 0.85, 'text': '<b>Value</b><br>Actual content',
'showarrow': False, 'font': {'size': 11, 'color': COLORS['value']}},
{'x': 0.5, 'y': 0.02, 'text': 'Attention(Q,K,V) = softmax(QKᵀ/√d)V',
'showarrow': False, 'font': {'size': 14, 'color': 'white'},
'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 8},
]
},
# LSTM Gate Architecture
'lstm_gates': {
'base_data': [
# Forget gate
{
'type': 'scatter',
'x': [0.2], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 55, 'color': '#e74c3c', 'symbol': 'square'},
'text': ['f'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Forget Gate',
'hovertemplate': '<b>Forget Gate (f)</b><br>Decides what to forget from cell state<br>f = σ(W·[h, x] + b)<extra></extra>'
},
# Input gate
{
'type': 'scatter',
'x': [0.4], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 55, 'color': '#2ecc71', 'symbol': 'square'},
'text': ['i'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Input Gate',
'hovertemplate': '<b>Input Gate (i)</b><br>Decides what new info to store<br>i = σ(W·[h, x] + b)<extra></extra>'
},
# Cell state candidate
{
'type': 'scatter',
'x': [0.6], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 55, 'color': '#9b59b6', 'symbol': 'square'},
'text': ['C̃'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Candidate',
'hovertemplate': '<b>Candidate Cell State (C̃)</b><br>New candidate values<br>C̃ = tanh(W·[h, x] + b)<extra></extra>'
},
# Output gate
{
'type': 'scatter',
'x': [0.8], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 55, 'color': '#3498db', 'symbol': 'square'},
'text': ['o'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Output Gate',
'hovertemplate': '<b>Output Gate (o)</b><br>Decides what to output<br>o = σ(W·[h, x] + b)<extra></extra>'
},
# Cell state line
{
'type': 'scatter',
'x': [0.1, 0.9], 'y': [0.4, 0.4],
'mode': 'lines',
'line': {'color': '#f39c12', 'width': 8},
'name': 'Cell State',
'hovertemplate': '<b>Cell State (C)</b><br>The "memory" that flows through time<extra></extra>'
},
# Hidden state output
{
'type': 'scatter',
'x': [0.8], 'y': [0.15],
'mode': 'markers+text',
'marker': {'size': 45, 'color': '#1abc9c', 'symbol': 'circle'},
'text': ['hₜ'],
'textposition': 'middle center',
'textfont': {'size': 16, 'color': 'white'},
'name': 'Hidden State',
'hovertemplate': '<b>Hidden State (hₜ)</b><br>Output at time t<br>hₜ = o * tanh(Cₜ)<extra></extra>'
},
],
'layout': {
'xaxis': {'visible': False, 'range': [0, 1]},
'yaxis': {'visible': False, 'range': [0, 1]},
'showlegend': False,
'height': 500,
},
'annotations': [
{'x': 0.2, 'y': 0.88, 'text': '<b>Forget</b>', 'showarrow': False,
'font': {'size': 12, 'color': '#e74c3c'}},
{'x': 0.4, 'y': 0.88, 'text': '<b>Input</b>', 'showarrow': False,
'font': {'size': 12, 'color': '#2ecc71'}},
{'x': 0.6, 'y': 0.88, 'text': '<b>Candidate</b>', 'showarrow': False,
'font': {'size': 12, 'color': '#9b59b6'}},
{'x': 0.8, 'y': 0.88, 'text': '<b>Output</b>', 'showarrow': False,
'font': {'size': 12, 'color': '#3498db'}},
{'x': 0.5, 'y': 0.4, 'text': '← Cell State (Long-term memory) →',
'showarrow': False, 'font': {'size': 12, 'color': '#f39c12'}},
]
},
# Variational Autoencoder
'vae_architecture': {
'base_data': [
# Encoder
{
'type': 'scatter',
'x': [0.15], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 80, 'color': COLORS['primary'], 'symbol': 'square'},
'text': ['Encoder'],
'textposition': 'middle center',
'textfont': {'size': 14, 'color': 'white'},
'name': 'Encoder',
'hovertemplate': '<b>Encoder</b><br>Compresses input to latent distribution<br>Outputs μ (mean) and σ (std)<extra></extra>'
},
# Latent space (mu and sigma)
{
'type': 'scatter',
'x': [0.45, 0.45], 'y': [0.65, 0.35],
'mode': 'markers+text',
'marker': {'size': 50, 'color': [COLORS['info'], COLORS['warning']], 'symbol': 'circle'},
'text': ['μ', 'σ'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Latent Params',
'hovertemplate': '%{text}: Latent parameter<extra></extra>'
},
# Sampling (z)
{
'type': 'scatter',
'x': [0.6], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 55, 'color': COLORS['accent'], 'symbol': 'diamond'},
'text': ['z'],
'textposition': 'middle center',
'textfont': {'size': 20, 'color': 'white'},
'name': 'Latent Code',
'hovertemplate': '<b>Latent Code (z)</b><br>z = μ + σ * ε<br>ε ~ N(0,1) (reparameterization trick)<extra></extra>'
},
# Decoder
{
'type': 'scatter',
'x': [0.85], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 80, 'color': COLORS['success'], 'symbol': 'square'},
'text': ['Decoder'],
'textposition': 'middle center',
'textfont': {'size': 14, 'color': 'white'},
'name': 'Decoder',
'hovertemplate': '<b>Decoder</b><br>Reconstructs input from latent code<extra></extra>'
},
],
'layout': {
'xaxis': {'visible': False, 'range': [0, 1]},
'yaxis': {'visible': False, 'range': [0, 1]},
'showlegend': False,
'height': 450,
'shapes': [
# Input arrow
{'type': 'line', 'x0': 0.02, 'y0': 0.5, 'x1': 0.08, 'y1': 0.5,
'line': {'color': 'white', 'width': 3}},
# Encoder to mu
{'type': 'line', 'x0': 0.22, 'y0': 0.55, 'x1': 0.4, 'y1': 0.65,
'line': {'color': COLORS['primary'], 'width': 2}},
# Encoder to sigma
{'type': 'line', 'x0': 0.22, 'y0': 0.45, 'x1': 0.4, 'y1': 0.35,
'line': {'color': COLORS['primary'], 'width': 2}},
# mu to z
{'type': 'line', 'x0': 0.5, 'y0': 0.62, 'x1': 0.55, 'y1': 0.52,
'line': {'color': COLORS['info'], 'width': 2}},
# sigma to z
{'type': 'line', 'x0': 0.5, 'y0': 0.38, 'x1': 0.55, 'y1': 0.48,
'line': {'color': COLORS['warning'], 'width': 2}},
# z to decoder
{'type': 'line', 'x0': 0.65, 'y0': 0.5, 'x1': 0.78, 'y1': 0.5,
'line': {'color': COLORS['accent'], 'width': 3}},
# Output arrow
{'type': 'line', 'x0': 0.92, 'y0': 0.5, 'x1': 0.98, 'y1': 0.5,
'line': {'color': 'white', 'width': 3}},
],
},
'annotations': [
{'x': 0.02, 'y': 0.58, 'text': 'Input x', 'showarrow': False,
'font': {'size': 11, 'color': 'white'}},
{'x': 0.98, 'y': 0.58, 'text': 'Output x̂', 'showarrow': False,
'font': {'size': 11, 'color': 'white'}},
{'x': 0.5, 'y': 0.15, 'text': 'Loss = Reconstruction + KL Divergence',
'showarrow': False, 'font': {'size': 12, 'color': 'white'},
'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 6},
]
},
# GAN Architecture
'gan_architecture': {
'base_data': [
# Random noise
{
'type': 'scatter',
'x': [0.1], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 45, 'color': '#95a5a6', 'symbol': 'circle'},
'text': ['z'],
'textposition': 'middle center',
'textfont': {'size': 18, 'color': 'white'},
'name': 'Noise',
'hovertemplate': '<b>Random Noise (z)</b><br>Sampled from normal distribution<extra></extra>'
},
# Generator
{
'type': 'scatter',
'x': [0.3], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 70, 'color': COLORS['success'], 'symbol': 'square'},
'text': ['G'],
'textposition': 'middle center',
'textfont': {'size': 24, 'color': 'white'},
'name': 'Generator',
'hovertemplate': '<b>Generator (G)</b><br>Creates fake data from noise<br>Tries to fool the Discriminator<extra></extra>'
},
# Generated data
{
'type': 'scatter',
'x': [0.5], 'y': [0.7],
'mode': 'markers+text',
'marker': {'size': 50, 'color': COLORS['warning'], 'symbol': 'diamond'},
'text': ['Fake'],
'textposition': 'middle center',
'textfont': {'size': 12, 'color': 'white'},
'name': 'Fake Data',
'hovertemplate': '<b>Generated (Fake) Data</b><br>Output from Generator<extra></extra>'
},
# Real data
{
'type': 'scatter',
'x': [0.5], 'y': [0.3],
'mode': 'markers+text',
'marker': {'size': 50, 'color': COLORS['info'], 'symbol': 'diamond'},
'text': ['Real'],
'textposition': 'middle center',
'textfont': {'size': 12, 'color': 'white'},
'name': 'Real Data',
'hovertemplate': '<b>Real Data</b><br>From training dataset<extra></extra>'
},
# Discriminator
{
'type': 'scatter',
'x': [0.75], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 70, 'color': COLORS['primary'], 'symbol': 'square'},
'text': ['D'],
'textposition': 'middle center',
'textfont': {'size': 24, 'color': 'white'},
'name': 'Discriminator',
'hovertemplate': '<b>Discriminator (D)</b><br>Classifies real vs fake<br>Tries to catch the Generator<extra></extra>'
},
# Output
{
'type': 'scatter',
'x': [0.92], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 40, 'color': '#ecf0f1', 'symbol': 'circle'},
'text': ['0/1'],
'textposition': 'middle center',
'textfont': {'size': 12, 'color': '#333'},
'name': 'Output',
'hovertemplate': '<b>Classification</b><br>1 = Real, 0 = Fake<extra></extra>'
},
],
'layout': {
'xaxis': {'visible': False, 'range': [0, 1]},
'yaxis': {'visible': False, 'range': [0, 1]},
'showlegend': False,
'height': 450,
'shapes': [
# Noise to G
{'type': 'line', 'x0': 0.15, 'y0': 0.7, 'x1': 0.23, 'y1': 0.7,
'line': {'color': '#95a5a6', 'width': 3}},
# G to Fake
{'type': 'line', 'x0': 0.37, 'y0': 0.7, 'x1': 0.45, 'y1': 0.7,
'line': {'color': COLORS['success'], 'width': 3}},
# Fake to D
{'type': 'line', 'x0': 0.55, 'y0': 0.65, 'x1': 0.68, 'y1': 0.55,
'line': {'color': COLORS['warning'], 'width': 3}},
# Real to D
{'type': 'line', 'x0': 0.55, 'y0': 0.35, 'x1': 0.68, 'y1': 0.45,
'line': {'color': COLORS['info'], 'width': 3}},
# D to output
{'type': 'line', 'x0': 0.82, 'y0': 0.5, 'x1': 0.88, 'y1': 0.5,
'line': {'color': COLORS['primary'], 'width': 3}},
],
},
'annotations': [
{'x': 0.3, 'y': 0.88, 'text': '<b>Generator</b><br>Creates fakes',
'showarrow': False, 'font': {'size': 11, 'color': COLORS['success']}},
{'x': 0.75, 'y': 0.88, 'text': '<b>Discriminator</b><br>Detects fakes',
'showarrow': False, 'font': {'size': 11, 'color': COLORS['primary']}},
{'x': 0.5, 'y': 0.08, 'text': 'Adversarial Training: G tries to fool D, D tries to catch G',
'showarrow': False, 'font': {'size': 12, 'color': 'white'},
'bgcolor': 'rgba(102, 126, 234, 0.8)', 'borderpad': 6},
]
},
# CNN Architecture
'cnn_architecture': {
'base_data': [
# Input image
{
'type': 'scatter',
'x': [0.08], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 60, 'color': '#3498db', 'symbol': 'square'},
'text': ['Input'],
'textposition': 'middle center',
'textfont': {'size': 11, 'color': 'white'},
'name': 'Input',
'hovertemplate': '<b>Input Image</b><br>Raw pixel values<extra></extra>'
},
# Conv layer 1
{
'type': 'scatter',
'x': [0.25], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 55, 'color': '#e74c3c', 'symbol': 'square'},
'text': ['Conv1'],
'textposition': 'middle center',
'textfont': {'size': 10, 'color': 'white'},
'name': 'Conv1',
'hovertemplate': '<b>Convolution Layer 1</b><br>Detects edges and simple patterns<extra></extra>'
},
# Pool layer 1
{
'type': 'scatter',
'x': [0.38], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 45, 'color': '#9b59b6', 'symbol': 'circle'},
'text': ['Pool'],
'textposition': 'middle center',
'textfont': {'size': 9, 'color': 'white'},
'name': 'Pool1',
'hovertemplate': '<b>Pooling Layer</b><br>Reduces spatial size<br>Keeps important features<extra></extra>'
},
# Conv layer 2
{
'type': 'scatter',
'x': [0.52], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 50, 'color': '#e74c3c', 'symbol': 'square'},
'text': ['Conv2'],
'textposition': 'middle center',
'textfont': {'size': 10, 'color': 'white'},
'name': 'Conv2',
'hovertemplate': '<b>Convolution Layer 2</b><br>Detects complex patterns<extra></extra>'
},
# Pool layer 2
{
'type': 'scatter',
'x': [0.65], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 40, 'color': '#9b59b6', 'symbol': 'circle'},
'text': ['Pool'],
'textposition': 'middle center',
'textfont': {'size': 9, 'color': 'white'},
'name': 'Pool2',
},
# Flatten
{
'type': 'scatter',
'x': [0.78], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 35, 'color': '#f39c12', 'symbol': 'diamond'},
'text': ['Flat'],
'textposition': 'middle center',
'textfont': {'size': 9, 'color': 'white'},
'name': 'Flatten',
'hovertemplate': '<b>Flatten</b><br>Convert 2D features to 1D vector<extra></extra>'
},
# FC layer
{
'type': 'scatter',
'x': [0.9], 'y': [0.5],
'mode': 'markers+text',
'marker': {'size': 45, 'color': '#2ecc71', 'symbol': 'square'},
'text': ['FC'],
'textposition': 'middle center',
'textfont': {'size': 11, 'color': 'white'},
'name': 'Fully Connected',
'hovertemplate': '<b>Fully Connected</b><br>Final classification layer<extra></extra>'
},
],
'layout': {
'xaxis': {'visible': False, 'range': [0, 1]},
'yaxis': {'visible': False, 'range': [0, 1]},
'showlegend': False,
'height': 400,
},
'annotations': [
{'x': 0.5, 'y': 0.85, 'text': '<b>CNN: Convolutional Neural Network</b>',
'showarrow': False, 'font': {'size': 14, 'color': 'white'}},
{'x': 0.3, 'y': 0.2, 'text': 'Feature Extraction',
'showarrow': False, 'font': {'size': 11, 'color': '#e74c3c'}},
{'x': 0.85, 'y': 0.2, 'text': 'Classification',
'showarrow': False, 'font': {'size': 11, 'color': '#2ecc71'}},
]
},
}
def get_template(name: str) -> dict:
"""Get a template by name, or None if not found."""
return TEMPLATES.get(name)
def list_templates() -> list:
"""List all available template names."""
return list(TEMPLATES.keys())