# Uploaded by adi-123 (commit 54dac2f, "Upload 5 files").
"""
AI Learning Playground - Interactive AI Concept Visualizations
See AI thinking. From neural networks to transformers.
"""
import gradio as gr
import os
from dotenv import load_dotenv
from generator import VisualizationGenerator
load_dotenv()

# Module-level state shared by the Gradio event handlers.
generator = None        # VisualizationGenerator instance once connected, otherwise None
current_config = None   # config dict from the most recent generate_visualization() call
current_topic = None    # topic string that produced current_config

# Try to connect automatically when a real key is present in the environment
# (the "your-api-key-here" placeholder is treated as "no key").
API_KEY = os.getenv("TOGETHER_API_KEY", "")
if API_KEY and API_KEY != "your-api-key-here":
    try:
        generator = VisualizationGenerator(api_key=API_KEY)
    except Exception as err:
        print(f"Connection failed: {err}")
    else:
        print("Connected to Together AI")
def format_evolution(evolution_data: dict) -> str:
    """Render the "evolution" section (what came before and what it fixed) as Markdown.

    Reads the optional keys ``predecessor``, ``predecessor_problem``,
    ``how_it_solves`` and ``key_innovation``. A predecessor equal to ``"none"``
    (any case) marks the concept as foundational. Missing or null values are
    treated as empty strings. Sections are separated by a blank line so each
    renders as its own paragraph.

    Returns:
        The Markdown text, or ``""`` when there is nothing to show.
    """
    if not evolution_data:
        return ""

    def _field(key: str) -> str:
        # Values may be null (or non-strings); normalize so .lower() and
        # f-strings never see None.
        value = evolution_data.get(key)
        return "" if value is None else str(value)

    predecessor = _field('predecessor')
    problem = _field('predecessor_problem')
    solution = _field('how_it_solves')
    innovation = _field('key_innovation')
    if not any([predecessor, problem, solution, innovation]):
        return ""

    parts = []
    if predecessor.lower() == 'none':
        parts.append("**Foundational Concept**")
        if problem:
            parts.append(f"**Problem Solved:** {problem}")
    else:
        if predecessor:
            parts.append(f"**Before:** {predecessor}")
        # Show the problem even when no predecessor was named.
        if problem:
            parts.append(f"**The Problem:** {problem}")
    if solution:
        parts.append(f"**The Solution:** {solution}")
    if innovation:
        parts.append(f"**Key Innovation:** {innovation}")
    return '\n\n'.join(parts)
def format_math(math_data: dict) -> str:
    """Render the math section (equations plus a symbol glossary) as Markdown.

    Reads ``formulas`` (list of dicts with ``name``/``equation``/``description``),
    falls back to a single ``formula`` string when no formulas are present, and
    renders ``variables`` (list of dicts with ``symbol``/``meaning``) as a bullet
    list. Symbols not already starting with ``$`` are wrapped in ``$...$`` so the
    Markdown component (configured with ``$`` inline LaTeX delimiters) renders
    them as math. Non-dict list entries are skipped and null values are treated
    as empty strings.

    Returns:
        The Markdown text, or ``""`` when *math_data* is empty.
    """
    if not math_data:
        return ""

    def _field(item: dict, key: str) -> str:
        value = item.get(key)
        return "" if value is None else str(value)

    def _dict_items(key: str) -> list:
        # Tolerate null, non-list, or mixed-type values for list fields.
        raw = math_data.get(key)
        if not isinstance(raw, (list, tuple)):
            return []
        return [item for item in raw if isinstance(item, dict)]

    parts = []
    formulas = _dict_items('formulas')
    if formulas:
        parts.append("## Key Equations\n")
        for entry in formulas:
            name = _field(entry, 'name')
            eq = _field(entry, 'equation')
            desc = _field(entry, 'description')
            if name:
                parts.append(f"### {name}")
            if eq:
                parts.append(f"\n{eq}\n")
            if desc:
                parts.append(f"*{desc}*\n")

    formula = _field(math_data, 'formula')
    if formula and not formulas:
        parts.append(f"## Formula\n\n{formula}\n")

    variables = _dict_items('variables')
    if variables:
        parts.append("\n## What Each Symbol Means\n")
        for var in variables:
            symbol = _field(var, 'symbol')
            meaning = _field(var, 'meaning')
            # Wrap in $...$ for LaTeX rendering if not already wrapped
            if symbol and not symbol.startswith('$'):
                symbol = f"${symbol}$"
            # Separate the symbol from its meaning (they were previously
            # concatenated, e.g. "$x$input").
            separator = ": " if symbol and meaning else ""
            parts.append(f"- {symbol}{separator}{meaning}")
    return '\n'.join(parts)
_MAX_SLIDERS = 4  # number of slider components wired into the UI (slider1..slider4)


def _to_float(value, fallback: float) -> float:
    """Convert *value* to float, returning *fallback* when it is None or not numeric."""
    if value is None:
        return fallback
    try:
        return float(value)
    except (TypeError, ValueError):
        return fallback


def _is_valid_param(p) -> bool:
    """Return True if *p* can drive a slider.

    A usable param is a dict with a non-empty ``name`` (apply_params uses it as
    the key passed to ``generator.update_params``) whose ``min``/``max``/``default``
    are each missing/None or convertible to float.
    """
    if not isinstance(p, dict) or not p.get('name'):
        return False
    for key in ('min', 'max', 'default'):
        value = p.get(key)
        if value is None or isinstance(value, (int, float)):
            continue
        try:
            float(value)
        except (TypeError, ValueError):
            return False
    return True


def _slider_update(p: dict):
    """Build a visible slider ``gr.update`` from a validated param dict."""
    # Test for None rather than truthiness so legitimate zeros (e.g. min=0)
    # are kept instead of being replaced by the fallback.
    min_val = _to_float(p.get('min'), 1.0)
    max_val = _to_float(p.get('max'), 10.0)
    if min_val > max_val:
        min_val, max_val = max_val, min_val
    default_val = _to_float(p.get('default'), min_val)
    # Clamp default to valid range
    default_val = max(min_val, min(max_val, default_val))
    step_val = _to_float(p.get('step'), 1.0)
    if step_val <= 0:
        step_val = 1.0  # a zero/negative step cannot drive a slider
    return gr.update(
        visible=True,
        label=p.get('label') or p.get('name', 'Parameter'),
        minimum=min_val,
        maximum=max_val,
        value=default_val,
        step=step_val,
    )


def _hidden_outputs(title: str = "") -> list:
    """Outputs that blank every panel and hide the slider group, sliders and button."""
    return (
        [None, None, None, title, "", "", "", "", ""]
        + [gr.update(visible=False) for _ in range(_MAX_SLIDERS + 2)]
    )


def generate_visualization(topic: str):
    """Generate the multi-view visualization for *topic*.

    Returns a 15-element list in the order of ``all_outputs`` in the UI:
    three plots; six Markdown strings (title, one-liner, intuition,
    why-it-matters, evolution, math); the slider-group update; four slider
    updates; and the update-button update. On success it also records the
    config (with only the validated params) and the topic in module globals
    so ``apply_params`` can re-render with new slider values. Errors are
    printed and shown in the title slot.
    """
    global current_config, current_topic
    if not generator or not (topic or "").strip():
        return _hidden_outputs()
    try:
        result = generator.generate(topic)

        # Pad to exactly three plots; a missing/None 'figures' gives empty plots.
        figures = list(result.get('figures') or [])
        fig1, fig2, fig3 = (figures + [None, None, None])[:3]

        # Keep only params that can actually drive a slider.
        params = []
        for p in result.get('params') or []:
            if _is_valid_param(p):
                params.append(p)
            else:
                print(f"Skipping invalid param: {p}")

        # Store the validated list so apply_params' "slider i -> params[i]"
        # mapping matches the slider the user actually saw. With the raw list,
        # every slider after a skipped param would send its value under the
        # wrong name.
        current_config = {**result, 'params': params}
        current_topic = topic

        shown = params[:_MAX_SLIDERS]
        slider_updates = [_slider_update(p) for p in shown]
        slider_updates += [gr.update(visible=False) for _ in range(_MAX_SLIDERS - len(shown))]

        return [
            fig1, fig2, fig3,
            f"# {result.get('title') or topic}",
            result.get('oneliner', ''),
            result.get('intuition', ''),
            result.get('why_it_matters', ''),
            format_evolution(result.get('evolution', {})),
            format_math(result.get('math', {})),
            gr.update(visible=bool(params)),
        ] + slider_updates + [gr.update(visible=bool(params))]
    except Exception as e:
        import traceback
        traceback.print_exc()
        return _hidden_outputs(f"# Error: {str(e)}")
def apply_params(s1, s2, s3, s4):
    """Re-render the three plots for the current topic using the slider values.

    Slider i is paired with ``current_config['params'][i]`` (first four only).
    Its value is sent under that param's ``name`` to
    ``generator.update_params(current_topic, ...)``. Sliders whose value is None,
    and params that are not dicts or have no name, are skipped without shifting
    the pairing.

    Returns:
        A 3-tuple of figures, with None for any missing view. Returns
        (None, None, None) when nothing has been generated yet or the update fails.
    """
    if not current_config or not generator:
        return None, None, None
    params = current_config.get('params') or []
    if not params:
        return None, None, None

    param_values = {}
    for p, value in zip(params[:4], (s1, s2, s3, s4)):
        name = p.get('name') if isinstance(p, dict) else None
        if name and value is not None:
            param_values[name] = value

    try:
        figures = list(generator.update_params(current_topic, param_values) or [])
    except Exception as e:
        print(f"Update error: {e}")
        return None, None, None
    fig1, fig2, fig3 = (figures + [None, None, None])[:3]
    return fig1, fig2, fig3
# Theme: Gradio's Soft preset with indigo/purple accents and slate neutrals.
# It is applied, together with custom_css, by the app.launch() call in the
# __main__ guard at the end of the file.
app_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="purple", neutral_hue="slate")
# Extra CSS for the layout. Selectors target the "header" <div> emitted by
# gr.HTML and the elem_classes assigned to components (oneliner, intuition,
# why-matters, evolution, slider-panel, topic-btn, update-btn).
# NOTE(review): .viz-grid and .card do not appear to be used by any component
# in this file; confirm before relying on them.
custom_css = """
.gradio-container { max-width: 1600px !important; }
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1.5rem 2rem;
border-radius: 16px;
margin-bottom: 1rem;
text-align: center;
}
.header h1 { color: white; margin: 0; font-size: 2rem; }
.header p { color: rgba(255,255,255,0.9); margin: 0.3rem 0 0 0; font-size: 1.1rem; }
.viz-grid { display: grid; gap: 1rem; }
.card {
background: white;
border-radius: 12px;
padding: 1rem;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
}
.oneliner {
background: linear-gradient(135deg, #667eea, #764ba2);
color: white !important;
padding: 1rem;
border-radius: 10px;
font-size: 1.1rem;
text-align: center;
}
.oneliner p { color: white !important; margin: 0; }
.intuition {
background: #f0f4ff;
border-left: 4px solid #667eea;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #333 !important;
}
.intuition p { color: #333 !important; }
.why-matters {
background: #fff8e6;
border-left: 4px solid #f59e0b;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #333 !important;
}
.why-matters p { color: #333 !important; }
.evolution {
background: linear-gradient(135deg, #e8f5e9, #c8e6c9);
border-left: 4px solid #43a047;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #1a1a1a !important;
}
.evolution p { color: #1a1a1a !important; }
.evolution strong { color: #2e7d32 !important; }
.evolution em { color: #1a1a1a !important; font-style: italic; }
.slider-panel {
background: linear-gradient(145deg, #f0f0ff, #fff);
border: 2px solid #667eea;
border-radius: 12px;
padding: 1rem;
}
.topic-btn { transition: all 0.2s ease; }
.topic-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.update-btn {
background: linear-gradient(135deg, #43e97b, #38f9d7) !important;
font-weight: 600;
font-size: 1rem;
}
"""
# Build the UI. When no generator is available at import time, an API-key row
# is shown and the main interface stays hidden until manual_connect succeeds.
with gr.Blocks(title="AI Learning Playground") as app:
    gr.HTML("""
<div class="header">
<h1>AI Learning Playground</h1>
<p>See AI thinking. Interactive visualizations from basics to cutting edge.</p>
</div>
""")
    api_connected = generator is not None
    if not api_connected:
        with gr.Row():
            api_key_input = gr.Textbox(label="Together AI API Key", type="password", scale=5)
            connect_btn = gr.Button("Connect", variant="primary", scale=1)
        api_status = gr.Markdown("")

    with gr.Column(visible=api_connected) as main_interface:
        # Input
        with gr.Row():
            topic_input = gr.Textbox(
                label="What concept do you want to understand?",
                placeholder="Try: Gradient Descent, K-Means Clustering, Attention Mechanism, Neural Network...",
                scale=5,
            )
            generate_btn = gr.Button("Visualize", variant="primary", scale=1)

        # Quick topics - diverse concepts showcasing app's power
        gr.Markdown("**Explore concepts:**")
        with gr.Row():
            topics = [
                "Transformer",       # LLM/NLP flagship
                "CNN",               # Computer Vision
                "GAN",               # Generative AI
                "LSTM",              # Sequence modeling
                "Backpropagation",   # ML fundamental
            ]
            topic_btns = [gr.Button(t, size="sm", elem_classes="topic-btn") for t in topics]

        # Title and one-liner
        title_output = gr.Markdown()
        oneliner_output = gr.Markdown(elem_classes="oneliner")

        # Multi-view visualizations
        gr.Markdown("### Multiple Views - Same Concept")
        with gr.Row():
            plot1 = gr.Plot(label="View 1")
            plot2 = gr.Plot(label="View 2")
            plot3 = gr.Plot(label="View 3")

        # Parameter controls (shown/configured by generate_visualization)
        with gr.Group(visible=False, elem_classes="slider-panel") as slider_group:
            gr.Markdown("### Adjust Parameters - See What Changes")
            with gr.Row():
                slider1 = gr.Slider(minimum=1, maximum=10, value=3, step=1, visible=False)
                slider2 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
            with gr.Row():
                slider3 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
                slider4 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
            update_btn = gr.Button("Update All Visualizations", elem_classes="update-btn", visible=False)

        # Explanations
        with gr.Row():
            with gr.Column():
                gr.Markdown("### The Intuition")
                intuition_output = gr.Markdown(elem_classes="intuition")
            with gr.Column():
                gr.Markdown("### Why It Matters")
                why_output = gr.Markdown(elem_classes="why-matters")

        # Evolution - How this concept improved on the past
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Evolution: What Problem Does This Solve?")
                evolution_output = gr.Markdown(elem_classes="evolution")

        # Math: $$...$$ renders as display LaTeX, $...$ inline
        with gr.Accordion("Mathematical Details", open=True):
            math_output = gr.Markdown(
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False}
                ]
            )

    # Event Handlers
    if not api_connected:
        def manual_connect(key):
            """Create the global generator from a user-entered key.

            Returns the status message and a visibility update for the main
            interface (shown only when the generator was created).
            """
            global generator
            # Use the stripped key: pasted keys often carry whitespace/newlines,
            # and the Textbox value may be None.
            key = (key or "").strip()
            if not key:
                return "Enter API key", gr.update(visible=False)
            try:
                generator = VisualizationGenerator(api_key=key)
                return "Connected!", gr.update(visible=True)
            except Exception as e:
                return f"Error: {e}", gr.update(visible=False)

        connect_btn.click(manual_connect, [api_key_input], [api_status, main_interface])

    # Order must match the list returned by generate_visualization.
    all_outputs = [
        plot1, plot2, plot3,
        title_output, oneliner_output, intuition_output, why_output, evolution_output, math_output,
        slider_group, slider1, slider2, slider3, slider4, update_btn
    ]
    generate_btn.click(generate_visualization, [topic_input], all_outputs)
    topic_input.submit(generate_visualization, [topic_input], all_outputs)
    for btn, topic in zip(topic_btns, topics):
        # Bind the topic as a default argument so each button keeps its own value
        # (avoids the late-binding closure pitfall); copy it into the textbox, then generate.
        btn.click(lambda t=topic: t, outputs=[topic_input]).then(
            generate_visualization, [topic_input], all_outputs
        )
    sliders = [slider1, slider2, slider3, slider4]
    update_btn.click(apply_params, sliders, [plot1, plot2, plot3])
if __name__ == "__main__":
    # Serve on every network interface at port 7860, applying the theme and CSS at launch.
    app.launch(
        theme=app_theme,
        css=custom_css,
        server_name="0.0.0.0",
        server_port=7860,
    )