Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import os | |
| import re | |
| import shutil | |
| from pathlib import Path | |
| import tempfile | |
# google-generativeai is optional at import time: without it the UI still loads
# and generate_code_with_gemini returns an explanatory error message instead.
try:
    import google.generativeai as genai
    GENAI_AVAILABLE = True
except ImportError:
    GENAI_AVAILABLE = False
    # Previously an empty print(); emit an actual warning so the missing
    # dependency is visible in the server logs.
    print("⚠️ google-generativeai is not installed; Gemini code generation is disabled.")

# Gemini API key, read once from the environment at import time. Surrounding
# whitespace (e.g. a trailing newline pasted into a secret) is stripped so it
# cannot break authentication; an empty string means "not configured".
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "").strip()
# Manim colour constants mapped onto the ten base colours the generation
# prompt allows (BLUE, RED, GREEN, YELLOW, ORANGE, PURPLE, PINK, TEAL, WHITE, BLACK).
_COLOR_MAP = {
    'ORANGE_A': 'ORANGE', 'ORANGE_B': 'ORANGE', 'ORANGE_C': 'ORANGE', 'ORANGE_D': 'ORANGE', 'ORANGE_E': 'ORANGE',
    'GRAY_A': 'WHITE', 'GRAY_B': 'WHITE', 'GRAY_C': 'WHITE', 'GRAY_D': 'WHITE', 'GRAY_E': 'WHITE',
    'GREY': 'WHITE', 'GRAY': 'WHITE', 'GREY_A': 'WHITE', 'GREY_B': 'WHITE',
    'PURPLE_A': 'PURPLE', 'PURPLE_B': 'PURPLE', 'PURPLE_C': 'PURPLE', 'PURPLE_D': 'PURPLE', 'PURPLE_E': 'PURPLE',
    'RED_A': 'RED', 'RED_B': 'RED', 'RED_C': 'RED', 'RED_D': 'RED', 'RED_E': 'RED',
    'BLUE_A': 'BLUE', 'BLUE_B': 'BLUE', 'BLUE_C': 'BLUE', 'BLUE_D': 'BLUE', 'BLUE_E': 'BLUE',
    'GREEN_A': 'GREEN', 'GREEN_B': 'GREEN', 'GREEN_C': 'GREEN', 'GREEN_D': 'GREEN', 'GREEN_E': 'GREEN',
    'YELLOW_A': 'YELLOW', 'YELLOW_B': 'YELLOW', 'YELLOW_C': 'YELLOW', 'YELLOW_D': 'YELLOW', 'YELLOW_E': 'YELLOW',
    'BROWN': 'ORANGE', 'GOLD': 'YELLOW', 'SILVER': 'WHITE', 'BRONZE': 'ORANGE',
    'MAROON': 'RED', 'NAVY': 'BLUE', 'CYAN': 'TEAL', 'MAGENTA': 'PINK', 'LIME': 'GREEN',
}
# One alternation over every colour name, longest first, wrapped in \b so only
# whole identifiers are rewritten. No replacement value is itself a key, so a
# single pass gives the same result as substituting each name in turn.
_COLOR_PATTERN = re.compile(
    r'\b(' + '|'.join(sorted(map(re.escape, _COLOR_MAP), key=len, reverse=True)) + r')\b'
)

# Largest multiplier allowed per shift direction, kept as text so clamped
# output reads e.g. ".shift(UP*1.5)".
_SHIFT_LIMITS = {'UP': '1.5', 'DOWN': '1.5', 'LEFT': '3', 'RIGHT': '3'}
_SHIFT_PATTERN = re.compile(r'\.shift\((UP|DOWN|LEFT|RIGHT)\s*\*\s*(\d+\.?\d*)\)')

_MAX_FONT_SIZE = 36
# Tolerates spaces around "=" and fractional sizes.
_FONT_SIZE_PATTERN = re.compile(r'font_size\s*=\s*(\d+(?:\.\d+)?)')

# Only bare to_edge(DIR) calls; calls that already pass buff= do not match.
_TO_EDGE_PATTERN = re.compile(r'\.to_edge\(\s*(UP|DOWN|LEFT|RIGHT)\s*\)')


def _clamp_shift(match):
    """Keep a shift within its direction's limit unchanged; clamp larger ones to the limit."""
    direction, amount = match.group(1), match.group(2)
    limit = _SHIFT_LIMITS[direction]
    if float(amount) <= float(limit):
        return match.group(0)
    return f'.shift({direction}*{limit})'


def _clamp_font_size(match):
    """Keep a font_size at or below _MAX_FONT_SIZE unchanged; cap larger ones."""
    if float(match.group(1)) <= _MAX_FONT_SIZE:
        return match.group(0)
    return f'font_size={_MAX_FONT_SIZE}'


def fix_all_known_errors(code):
    """Rewrite generated Manim source to avoid known failure modes.

    Every fix is a textual regex substitution over the whole source string,
    string literals included.

    Args:
        code: Python/Manim source code as a string.

    Returns:
        The rewritten source code.
    """
    # 1. Map colour variants / non-base colours onto the base palette.
    code = _COLOR_PATTERN.sub(lambda m: _COLOR_MAP[m.group(1)], code)

    # 2. Replace objects the generator must not use with simple placeholders.
    code = re.sub(r'Checkmark\([^)]*\)', 'Circle(radius=0.3, color=GREEN, fill_opacity=1)', code)
    code = re.sub(r'CheckMark\([^)]*\)', 'Circle(radius=0.3, color=GREEN, fill_opacity=1)', code)
    code = re.sub(r'XMark\([^)]*\)', 'Cross(color=RED).scale(0.5)', code)
    code = re.sub(r'SVGMobject\([^)]+\)', 'Circle(radius=0.8, color=PINK, fill_opacity=0.8)', code)
    code = re.sub(r'ImageMobject\([^)]+\)', 'Square(side_length=2, color=BLUE, fill_opacity=0.5)', code)

    # 3. Drop DecimalNumber prefix/suffix calls.
    code = re.sub(r'\.add_prefix\([^)]+\)', '', code)
    code = re.sub(r'\.add_suffix\([^)]+\)', '', code)

    # 4. Replace axis-label helpers with plain Text placed next to the axes.
    code = re.sub(r'(\w+)\s*=\s*axes\.get_x_axis_label\([^)]+\)',
                  r'\1 = Text("X", font_size=24).next_to(axes.x_axis, DOWN, buff=0.3)', code)
    code = re.sub(r'(\w+)\s*=\s*axes\.get_y_axis_label\([^)]+\)',
                  r'\1 = Text("Y", font_size=24).next_to(axes.y_axis, LEFT, buff=0.3).rotate(90*DEGREES)', code)
    code = re.sub(r'\.get_x_axis_label\([^)]+\)', '', code)
    code = re.sub(r'\.get_y_axis_label\([^)]+\)', '', code)

    # 5. Cap font sizes (smaller sizes are left untouched).
    code = _FONT_SIZE_PATTERN.sub(_clamp_font_size, code)

    # 6. Clamp shifts to the safe maximum. Small shifts are kept as written;
    #    the previous version inflated them to the maximum.
    code = _SHIFT_PATTERN.sub(_clamp_shift, code)

    # 7. Force buff=1 on every bare to_edge(DIR), including calls followed by
    #    a comma (e.g. inside VGroup(...)), which the old lookahead skipped.
    code = _TO_EDGE_PATTERN.sub(r'.to_edge(\1, buff=1)', code)

    # 8. Replace Tex with MathTex (\b avoids touching existing MathTex).
    code = re.sub(r'\bTex\(', 'MathTex(', code)
    return code
def _build_gemini_prompt(prompt):
    """Return the full instruction prompt sent to Gemini for the user's *prompt*."""
    # The structure example no longer sets a WHITE background: Manim's default
    # text colour is WHITE, so the template's own Text objects would be invisible.
    return f"""Generate Manim code. CRITICAL RULES:
1. COLORS - ONLY these 10: BLUE, RED, GREEN, YELLOW, ORANGE, PURPLE, PINK, TEAL, WHITE, BLACK
NO variants (_A, _B, _C, _D, _E)
2. TEXT OVERLAP PREVENTION (MOST IMPORTANT):
- ALWAYS FadeOut text before showing new text
- NEVER reuse the same position without FadeOut first
- Pattern for EVERY text:
```python
title = Text("Title", font_size=36).to_edge(UP, buff=1)
self.play(Write(title))
self.wait(1.5)
self.play(FadeOut(title)) # ← MANDATORY!
# Now safe to reuse UP position
subtitle = Text("Next", font_size=32).to_edge(UP, buff=1)
self.play(Write(subtitle))
self.wait(1.5)
self.play(FadeOut(subtitle)) # ← MANDATORY!
```
3. SAFE BOUNDARIES (CRITICAL):
- Font size: MAX 36 (never larger!)
- Shifts: UP*1.5, DOWN*1.5, LEFT*3, RIGHT*3 (MAX!)
- Always use buff=1 with to_edge()
- Safe zone: X[-4, 4], Y[-2, 2]
4. NEVER use:
- get_x_axis_label(), get_y_axis_label()
- Checkmark, SVGMobject, ImageMobject
- .add_prefix(), .add_suffix()
5. For axis labels:
x_label = Text("Time", font_size=24).next_to(axes.x_axis, DOWN, buff=0.5)
y_label = Text("Value", font_size=24).next_to(axes.y_axis, LEFT, buff=0.5)
6. STRUCTURE - Follow this pattern:
```python
from manim import *
class MyScene(Scene):
    def construct(self):
        # Section 1
        text1 = Text("First", font_size=36).to_edge(UP, buff=1)
        self.play(Write(text1))
        self.wait(1.5)
        self.play(FadeOut(text1)) # ← Clean up!
        # Section 2
        text2 = Text("Second", font_size=32).to_edge(UP, buff=1)
        self.play(Write(text2))
        self.wait(1.5)
        self.play(FadeOut(text2)) # ← Clean up!
```
User wants: {prompt}
Generate complete code following ALL rules above. EVERY text MUST have FadeOut before next text!"""


def generate_code_with_gemini(prompt):
    """Generate Manim code for *prompt* using the Gemini API.

    Picks the first available model whose name contains "flash" (falling back
    to the first model supporting ``generateContent``) and sends it the rules
    prompt built by ``_build_gemini_prompt``.

    Args:
        prompt: The user's natural-language description of the animation.

    Returns:
        A ``(code, error)`` tuple: ``(response_text, None)`` on success, or
        ``(None, message)`` describing why generation could not run. Errors
        are returned rather than raised.
    """
    if not GENAI_AVAILABLE:
        return None, "❌ google-generativeai package not installed"
    if not GEMINI_API_KEY:
        return None, "❌ GEMINI_API_KEY not set in environment variables"
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        models = [
            m.name
            for m in genai.list_models()
            if 'generateContent' in m.supported_generation_methods
        ]
        # Guard explicitly: next()'s default argument is evaluated eagerly, so
        # models[0] on an empty list surfaced as an opaque IndexError.
        if not models:
            return None, "❌ No Gemini model supporting generateContent is available for this API key"
        model_name = next((m for m in models if 'flash' in m.lower()), models[0])
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(_build_gemini_prompt(prompt))
        return response.text, None
    except Exception as e:
        # Top-level boundary for a UI callback: report any SDK/network failure
        # as a user-facing message instead of crashing the request.
        return None, f"❌ Gemini API error: {e}"
# Upper bound on a single manim render so a runaway scene cannot hold a
# worker forever.
_RENDER_TIMEOUT_SECONDS = 600


def _extract_python_code(text):
    """Return the code inside the first Markdown fence of *text* (if any), stripped."""
    if "```python" in text:
        text = text.split("```python")[1].split("```")[0]
    elif "```" in text:
        text = text.split("```")[1].split("```")[0]
    return text.strip()


def _find_rendered_video(media_dir, class_name):
    """Return the final rendered .mp4 for *class_name* under *media_dir*, or None."""
    if not media_dir.is_dir():
        return None
    candidates = sorted(media_dir.rglob("*.mp4"))
    # manim also writes per-animation clips into a partial_movie_files/
    # directory; prefer the combined movie named after the scene.
    exact = [p for p in candidates if p.name == f"{class_name}.mp4"]
    if exact:
        return exact[0]
    complete = [p for p in candidates if "partial_movie_files" not in p.parts]
    if complete:
        return complete[0]
    return candidates[0] if candidates else None


def render_video(prompt, quality="low"):
    """Generate Manim code for *prompt* with Gemini and render it to a video.

    This is a generator used as a streaming Gradio callback. Each yielded value
    is a ``(status_message, code_or_None, video_path_or_None)`` tuple.

    Args:
        prompt: The user's description of the animation.
        quality: "low", "medium" or "high"; unknown values fall back to low.

    The per-request temporary directory is removed when the generator
    finishes, on every exit path (success, early return, or error).
    """
    if not prompt or not prompt.strip():
        yield "❌ Please enter a prompt!", None, None
        return

    temp_dir = None
    try:
        temp_dir = Path(tempfile.mkdtemp(prefix="manim_"))
        yield "🤖 Generating code with Gemini AI...", None, None

        code, error = generate_code_with_gemini(prompt)
        if error:
            yield error, None, None
            return
        if not code:
            yield "❌ Failed to generate code", None, None
            return

        code = fix_all_known_errors(_extract_python_code(code))

        # Accept any *Scene base (Scene, MovingCameraScene, ThreeDScene, ...).
        match = re.search(r'class\s+(\w+)\s*\(\s*\w*Scene\s*\)', code)
        if not match:
            yield "❌ No Scene class found in generated code!", code, None
            return
        class_name = match.group(1)

        code_file = temp_dir / "animation.py"
        code_file.write_text(code, encoding="utf-8")
        yield f"✓ Code generated (Scene: {class_name})\n🎬 Rendering video (this may take 1-2 minutes)...", code, None

        quality_map = {'low': '-ql', 'medium': '-qm', 'high': '-qh'}
        media_dir = (temp_dir / "media").absolute()
        # NOTE(security): this executes model-generated Python, whose content
        # is steered by untrusted user input. Run only in a sandboxed,
        # disposable environment.
        # An argument list (no shell) keeps paths containing spaces intact.
        command = [
            "manim",
            str(code_file.absolute()),
            class_name,
            quality_map.get(quality, '-ql'),
            "--disable_caching",
            "--media_dir",
            str(media_dir),
        ]
        try:
            result = subprocess.run(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                errors="replace",
                timeout=_RENDER_TIMEOUT_SECONDS,
            )
        except subprocess.TimeoutExpired:
            yield f"❌ Rendering timed out after {_RENDER_TIMEOUT_SECONDS} seconds", code, None
            return
        except FileNotFoundError:
            yield "❌ The 'manim' command was not found on PATH", code, None
            return

        if result.returncode != 0:
            error_msg = ''.join(result.stdout.splitlines(keepends=True)[-30:])  # Last 30 lines
            yield f"❌ Rendering failed:\n\n{error_msg}", code, None
            return

        video_path = _find_rendered_video(media_dir, class_name)
        if video_path is None:
            yield "❌ Video file not found after rendering", code, None
        else:
            yield "✅ Video rendered successfully! 🎉", code, str(video_path)
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        yield f"❌ Error: {str(e)}\n\nDetails:\n{error_details}", None, None
    finally:
        # Best-effort cleanup; previously skipped on early returns and errors.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
| # Gradio Interface | |
def create_interface():
    """Build the Gradio Blocks UI for the Manim video generator.

    The page is laid out top to bottom as:
    - an introductory Markdown section;
    - a row whose left column holds the prompt box, quality selector and
      generate button, and whose right column holds the status box, video
      player and generated-code viewer;
    - a tips section;
    - a set of example prompts.

    The generate button is wired to ``render_video``, whose yielded tuples
    fill the status, code and video outputs.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    intro_markdown = """
# 🎬 Manim Video Generator
Create mathematical animations using AI! Describe what you want to see animated.
**Examples:**
- "Explain Pythagorean theorem with animation"
- "Show a sine wave transforming into a cosine wave"
- "Animate a circle morphing into a square"
- "Show the concept of derivatives with a tangent line"
"""
    tips_markdown = """
### 💡 Tips:
- Be specific in your description
- Start with simple prompts to test
- Low quality renders faster (30 seconds to 1 minute)
- Medium/High quality may take 2-3 minutes
- Complex animations may timeout on free tier
"""
    example_prompts = [
        ["Show a blue circle morphing into a red square", "low"],
        ["Animate the Pythagorean theorem: a² + b² = c²", "low"],
        ["Show a sine wave moving across the screen", "low"],
        ["Create an animation showing 3 dots forming a triangle", "low"],
        ["Show a number counting from 0 to 10", "low"],
    ]

    with gr.Blocks(title="Manim Video Generator") as app:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                prompt_box = gr.Textbox(
                    lines=4,
                    label="What animation do you want?",
                    placeholder="e.g., Show a circle morphing into a square",
                )
                quality_radio = gr.Radio(
                    label="Quality (low is faster, recommended)",
                    choices=["low", "medium", "high"],
                    value="low",
                )
                run_button = gr.Button("🎬 Generate Video", size="lg", variant="primary")
            # Right column: streamed outputs.
            with gr.Column():
                status_box = gr.Textbox(lines=4, label="Status")
                video_view = gr.Video(label="Generated Animation")
                code_view = gr.Code(lines=15, language="python", label="Generated Manim Code")

        gr.Markdown(tips_markdown)

        shared_inputs = [prompt_box, quality_radio]
        run_button.click(
            fn=render_video,
            inputs=shared_inputs,
            outputs=[status_box, code_view, video_view],
        )

        # Clickable examples that populate the two input components.
        gr.Examples(examples=example_prompts, inputs=shared_inputs)

    return app
if __name__ == "__main__":
    # Script entry point: build the UI (bound to the module-level name `demo`)
    # and start the Gradio server.
    (demo := create_interface()).launch()