Spaces:
Sleeping
Sleeping
| """ | |
| AI Learning Playground - Interactive AI Concept Visualizations | |
| See AI thinking. From neural networks to transformers. | |
| """ | |
| import gradio as gr | |
| import os | |
| from dotenv import load_dotenv | |
| from generator import VisualizationGenerator | |
load_dotenv()  # load variables from a local .env file into os.environ

# Process-wide state read and written by the Gradio callbacks.
# NOTE(review): as module globals these are shared by every browser session;
# per-session gr.State would be needed to isolate concurrent users.
generator = None       # VisualizationGenerator, once an API key is available
current_config = None  # result of the last generator.generate() call
current_topic = None   # topic string that produced current_config

API_KEY = os.getenv("TOGETHER_API_KEY", "")
# Only auto-connect with a real key: skip the empty default and the template placeholder.
if API_KEY not in ("", "your-api-key-here"):
    try:
        generator = VisualizationGenerator(api_key=API_KEY)
    except Exception as e:
        print(f"Connection failed: {e}")
    else:
        print("Connected to Together AI")
def format_evolution(evolution_data: dict) -> str:
    """Render the "evolution" section (what came before, what this fixes) as Markdown.

    Args:
        evolution_data: Mapping that may contain the fields ``predecessor``,
            ``predecessor_problem``, ``how_it_solves`` and ``key_innovation``.
            A predecessor of ``"none"`` (any case) marks the concept as
            foundational. ``None`` values, missing keys and non-dict input
            are tolerated.

    Returns:
        Markdown with one bold-labelled paragraph per non-empty field, or
        ``""`` when there is nothing to show.
    """
    if not isinstance(evolution_data, dict) or not evolution_data:
        return ""

    def _field(key: str) -> str:
        # LLM JSON may contain null or non-string values; normalise to a stripped str
        # so the .lower() comparison below can never hit None.
        value = evolution_data.get(key)
        return "" if value is None else str(value).strip()

    predecessor = _field('predecessor')
    problem = _field('predecessor_problem')
    solution = _field('how_it_solves')
    innovation = _field('key_innovation')
    if not any([predecessor, problem, solution, innovation]):
        return ""

    parts = []
    if predecessor.lower() == 'none':
        parts.append("**Foundational Concept**")
        if problem:
            parts.append(f"**Problem Solved:** {problem}")
    else:
        if predecessor:
            parts.append(f"**Before:** {predecessor}")
        if problem:
            # Shown even when no predecessor is named (previously dropped).
            parts.append(f"**The Problem:** {problem}")
    if solution:
        parts.append(f"**The Solution:** {solution}")
    if innovation:
        parts.append(f"**Key Innovation:** {innovation}")
    # A blank line between entries makes Markdown render separate paragraphs.
    return '\n\n'.join(parts)
def format_math(math_data: dict) -> str:
    """Render the "math" section as Markdown with LaTeX-friendly symbols.

    Args:
        math_data: Mapping that may contain
            - ``formulas``: list of ``{'name', 'equation', 'description'}``
              dicts (a bare string entry is treated as an equation, and a
              single dict/string in place of a list is accepted),
            - ``formula``: one formula string, shown only when ``formulas``
              is empty,
            - ``variables``: list of ``{'symbol', 'meaning'}`` dicts.
            ``None`` values are treated as missing.

    Returns:
        Markdown text, or ``""`` when ``math_data`` is empty or not a dict.
    """
    if not isinstance(math_data, dict) or not math_data:
        return ""

    def _text(value) -> str:
        # LLM JSON may contain null; never render the literal string "None".
        return "" if value is None else str(value)

    def _as_list(value) -> list:
        # Accept a lone item where a list was expected, instead of iterating
        # over the characters of a string or the keys of a dict.
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    parts = []
    formulas = _as_list(math_data.get('formulas'))
    if formulas:
        parts.append("## Key Equations\n")
        for entry in formulas:
            if isinstance(entry, dict):
                name = _text(entry.get('name'))
                eq = _text(entry.get('equation'))
                desc = _text(entry.get('description'))
            else:
                name, eq, desc = "", _text(entry), ""
            if name:
                parts.append(f"### {name}")
            if eq:
                parts.append(f"\n{eq}\n")
            if desc:
                parts.append(f"*{desc}*\n")

    formula = _text(math_data.get('formula'))
    if formula and not formulas:
        parts.append(f"## Formula\n\n{formula}\n")

    variables = [v for v in _as_list(math_data.get('variables')) if isinstance(v, dict)]
    if variables:
        parts.append("\n## What Each Symbol Means\n")
        for var in variables:
            symbol = _text(var.get('symbol'))
            # Wrap in $...$ for inline LaTeX rendering unless already wrapped.
            if symbol and not symbol.startswith('$'):
                symbol = f"${symbol}$"
            parts.append(f"- {symbol} — {_text(var.get('meaning'))}")
    return '\n'.join(parts)
def _coerce_float(value, fallback: float) -> float:
    """Convert an LLM-supplied numeric field to ``float``.

    ``fallback`` is used only when the field is absent (``None`` or a blank
    string). Unlike ``value or fallback``, a legitimate ``0`` is preserved.

    Raises:
        ValueError, TypeError: if the value is present but not a finite number.
    """
    import math

    if value is None or (isinstance(value, str) and not value.strip()):
        return float(fallback)
    number = float(value)
    if not math.isfinite(number):
        raise ValueError(f"non-finite value: {value!r}")
    return number


def _sanitize_params(raw_params) -> list:
    """Filter and normalise LLM-supplied slider specs.

    Entries are skipped (and logged) when they are not dicts, lack a
    ``name``, or carry non-numeric ``min``/``max``/``default``/``step``.
    Kept entries are shallow copies with those four keys coerced to floats:
    missing bounds fall back to 1..10, a missing default to the minimum, and
    a missing or non-positive step to 1. Reversed bounds are swapped and the
    default is clamped into range.
    """
    params = []
    for p in raw_params or []:
        if not isinstance(p, dict) or not p.get('name'):
            # apply_params looks sliders up by name, so a nameless param is unusable.
            print(f"Skipping invalid param: {p}")
            continue
        try:
            min_val = _coerce_float(p.get('min'), 1)
            max_val = _coerce_float(p.get('max'), 10)
            if min_val > max_val:
                min_val, max_val = max_val, min_val
            default_val = _coerce_float(p.get('default'), min_val)
            step_val = _coerce_float(p.get('step'), 1)
        except (ValueError, TypeError):
            print(f"Skipping invalid param: {p}")
            continue
        if step_val <= 0:
            step_val = 1.0
        span = max_val - min_val
        if 0 < span < step_val:
            # A step wider than the whole range would leave the slider unusable.
            step_val = span / 100
        params.append({
            **p,
            'min': min_val,
            'max': max_val,
            'default': max(min_val, min(max_val, default_val)),
            'step': step_val,
        })
    return params


def _slider_update(param: dict):
    """Build the ``gr.update`` that shows one slider configured from a sanitized param."""
    return gr.update(
        visible=True,
        label=param.get('label') or param.get('name', 'Parameter'),
        minimum=param['min'],
        maximum=param['max'],
        value=param['default'],
        step=param['step'],
    )


def generate_visualization(topic: str):
    """Generate the multi-view visualization for ``topic`` and refresh the UI.

    Returns:
        A 15-item list matching the UI's ``all_outputs``: three figures;
        six Markdown strings (title, one-liner, intuition, why it matters,
        evolution, math); then updates for the slider group, the four
        sliders and the update button. When there is no generator or the
        topic is blank, everything is cleared and hidden. On failure, the
        error is shown in the title slot.

    Side effects:
        On success, stores the generator result in ``current_config``. Its
        ``params`` are replaced by the sanitized list actually bound to the
        sliders, so ``apply_params`` maps slider i to the same parameter.
        Also stores ``topic`` in ``current_topic``.
    """
    global current_config, current_topic
    n_sliders = 4

    def hidden():
        return gr.update(visible=False)

    empty_result = (
        [None, None, None]                       # 3 plot outputs
        + [""] * 6                               # title, oneliner, intuition, why_it_matters, evolution, math
        + [hidden()]                             # slider group
        + [hidden() for _ in range(n_sliders)]   # sliders
        + [hidden()]                             # update button
    )
    if not generator or not topic or not topic.strip():
        return empty_result

    try:
        result = generator.generate(topic)
        params = _sanitize_params(result.get('params'))
        # Copy rather than mutate: the generator may keep a reference to `result`.
        current_config = {**result, 'params': params}
        current_topic = topic

        figures = list(result.get('figures') or [])[:3]
        figures += [None] * (3 - len(figures))

        slider_updates = [_slider_update(p) for p in params[:n_sliders]]
        slider_updates += [hidden() for _ in range(n_sliders - len(slider_updates))]

        has_params = bool(params)
        return [
            *figures,
            f"# {result.get('title') or topic}",
            result.get('oneliner', ''),
            result.get('intuition', ''),
            result.get('why_it_matters', ''),
            format_evolution(result.get('evolution', {})),
            format_math(result.get('math', {})),
            gr.update(visible=has_params),
        ] + slider_updates + [gr.update(visible=has_params)]
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (
            [None, None, None, f"# Error: {e}", "", "", "", "", ""]
            + [hidden() for _ in range(2 + n_sliders)]   # group, sliders, button
        )
def apply_params(s1, s2, s3, s4):
    """Re-render the three views with the current slider values.

    Slider ``i`` is bound by position to ``current_config['params'][i]``
    (first four only). Sliders whose value is ``None``, and malformed param
    entries without a ``name``, are left out of the update.

    Returns:
        ``(fig1, fig2, fig3)``. Slots the generator does not fill are ``None``.
        All three are ``None`` when there is nothing to update or the update
        fails; failures are logged with a traceback.
    """
    if not current_config or not generator:
        return None, None, None
    params = current_config.get('params') or []
    if not params:
        return None, None, None

    param_values = {}
    for p, value in zip(params[:4], (s1, s2, s3, s4)):
        # Skip malformed entries instead of raising KeyError outside the try below.
        name = p.get('name') if isinstance(p, dict) else None
        if name and value is not None:
            param_values[name] = value

    try:
        figures = list(generator.update_params(current_topic, param_values) or [])
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"Update error: {e}")
        return None, None, None

    figures = figures[:3] + [None] * max(0, 3 - len(figures))
    return tuple(figures)
# Visual theme: soft style with indigo/purple accents on a slate neutral palette.
app_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="purple",
    neutral_hue="slate",
)

# Stylesheet for the classes attached via `elem_classes` / the header HTML in the
# layout (header, oneliner, intuition, why-matters, evolution, slider-panel,
# topic-btn, update-btn). NOTE(review): `.viz-grid` and `.card` are not referenced
# by the visible layout code.
custom_css = """
.gradio-container { max-width: 1600px !important; }
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1.5rem 2rem;
border-radius: 16px;
margin-bottom: 1rem;
text-align: center;
}
.header h1 { color: white; margin: 0; font-size: 2rem; }
.header p { color: rgba(255,255,255,0.9); margin: 0.3rem 0 0 0; font-size: 1.1rem; }
.viz-grid { display: grid; gap: 1rem; }
.card {
background: white;
border-radius: 12px;
padding: 1rem;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
}
.oneliner {
background: linear-gradient(135deg, #667eea, #764ba2);
color: white !important;
padding: 1rem;
border-radius: 10px;
font-size: 1.1rem;
text-align: center;
}
.oneliner p { color: white !important; margin: 0; }
.intuition {
background: #f0f4ff;
border-left: 4px solid #667eea;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #333 !important;
}
.intuition p { color: #333 !important; }
.why-matters {
background: #fff8e6;
border-left: 4px solid #f59e0b;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #333 !important;
}
.why-matters p { color: #333 !important; }
.evolution {
background: linear-gradient(135deg, #e8f5e9, #c8e6c9);
border-left: 4px solid #43a047;
padding: 1rem;
border-radius: 0 8px 8px 0;
color: #1a1a1a !important;
}
.evolution p { color: #1a1a1a !important; }
.evolution strong { color: #2e7d32 !important; }
.evolution em { color: #1a1a1a !important; font-style: italic; }
.slider-panel {
background: linear-gradient(145deg, #f0f0ff, #fff);
border: 2px solid #667eea;
border-radius: 12px;
padding: 1rem;
}
.topic-btn { transition: all 0.2s ease; }
.topic-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.update-btn {
background: linear-gradient(135deg, #43e97b, #38f9d7) !important;
font-weight: 600;
font-size: 1rem;
}
"""
with gr.Blocks(title="AI Learning Playground") as app:
    gr.HTML("""
<div class="header">
<h1>AI Learning Playground</h1>
<p>See AI thinking. Interactive visualizations from basics to cutting edge.</p>
</div>
""")
    # Without an env-supplied key, show a key prompt and keep the main UI hidden.
    api_connected = generator is not None
    if not api_connected:
        with gr.Row():
            api_key_input = gr.Textbox(label="Together AI API Key", type="password", scale=5)
            connect_btn = gr.Button("Connect", variant="primary", scale=1)
        api_status = gr.Markdown("")

    with gr.Column(visible=api_connected) as main_interface:
        # Input
        with gr.Row():
            topic_input = gr.Textbox(
                label="What concept do you want to understand?",
                placeholder="Try: Gradient Descent, K-Means Clustering, Attention Mechanism, Neural Network...",
                scale=5,
            )
            generate_btn = gr.Button("Visualize", variant="primary", scale=1)

        # Quick topics - diverse concepts showcasing app's power
        gr.Markdown("**Explore concepts:**")
        with gr.Row():
            topics = [
                "Transformer",       # LLM/NLP flagship
                "CNN",               # Computer Vision
                "GAN",               # Generative AI
                "LSTM",              # Sequence modeling
                "Backpropagation",   # ML fundamental
            ]
            topic_btns = [gr.Button(t, size="sm", elem_classes="topic-btn") for t in topics]

        # Title and one-liner
        title_output = gr.Markdown()
        oneliner_output = gr.Markdown(elem_classes="oneliner")

        # Multi-view visualizations
        gr.Markdown("### Multiple Views - Same Concept")
        with gr.Row():
            plot1 = gr.Plot(label="View 1")
            plot2 = gr.Plot(label="View 2")
            plot3 = gr.Plot(label="View 3")

        # Parameter controls: hidden until generate_visualization configures them.
        with gr.Group(visible=False, elem_classes="slider-panel") as slider_group:
            gr.Markdown("### Adjust Parameters - See What Changes")
            with gr.Row():
                slider1 = gr.Slider(minimum=1, maximum=10, value=3, step=1, visible=False)
                slider2 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
            with gr.Row():
                slider3 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
                slider4 = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
            update_btn = gr.Button("Update All Visualizations", elem_classes="update-btn", visible=False)

        # Explanations
        with gr.Row():
            with gr.Column():
                gr.Markdown("### The Intuition")
                intuition_output = gr.Markdown(elem_classes="intuition")
            with gr.Column():
                gr.Markdown("### Why It Matters")
                why_output = gr.Markdown(elem_classes="why-matters")

        # Evolution - How this concept improved on the past
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Evolution: What Problem Does This Solve?")
                evolution_output = gr.Markdown(elem_classes="evolution")

        # Math: enable $$...$$ (display) and $...$ (inline) LaTeX rendering.
        with gr.Accordion("Mathematical Details", open=True):
            math_output = gr.Markdown(
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                ]
            )

    # Event Handlers
    if not api_connected:
        def manual_connect(key):
            """Create the shared generator from a pasted key and reveal the main UI on success."""
            global generator
            # Strip pasted whitespace/newlines: the key is validated AND used stripped.
            key = (key or "").strip()
            if not key:
                return "Enter API key", gr.update(visible=False)
            try:
                generator = VisualizationGenerator(api_key=key)
                return "Connected!", gr.update(visible=True)
            except Exception as e:
                return f"Error: {e}", gr.update(visible=False)

        connect_btn.click(manual_connect, [api_key_input], [api_status, main_interface])

    # Order must match the 15-item list returned by generate_visualization.
    all_outputs = [
        plot1, plot2, plot3,
        title_output, oneliner_output, intuition_output, why_output, evolution_output, math_output,
        slider_group, slider1, slider2, slider3, slider4, update_btn,
    ]
    generate_btn.click(generate_visualization, [topic_input], all_outputs)
    topic_input.submit(generate_visualization, [topic_input], all_outputs)
    for btn, topic in zip(topic_btns, topics):
        # `t=topic` binds the current loop value (avoids the late-binding closure pitfall).
        btn.click(lambda t=topic: t, outputs=[topic_input]).then(
            generate_visualization, [topic_input], all_outputs
        )

    sliders = [slider1, slider2, slider3, slider4]
    update_btn.click(apply_params, sliders, [plot1, plot2, plot3])
if __name__ == "__main__":
    # Theme and stylesheet are passed at launch time rather than to gr.Blocks.
    app.launch(
        theme=app_theme,
        css=custom_css,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )