Spaces:
Running
Running
| """ | |
| Example usage of PaperBanana framework. | |
| This script demonstrates how to use PaperBanana to generate academic illustrations. | |
| """ | |
| import os | |
| from paperbanana import PaperBanana, generate_illustration | |
# Methodology section of a made-up paper; the examples below pass this text
# to PaperBanana as `methodology_text`.
EXAMPLE_METHODOLOGY = """
Our proposed method consists of three main stages:
1. Feature Extraction: We use a pretrained ResNet-50 backbone to extract visual features
from input images. The features are pooled using adaptive average pooling to obtain
a fixed-size representation.
2. Attention Mechanism: We apply multi-head self-attention to capture long-range
dependencies between different spatial regions. The attention module has 8 heads
and uses scaled dot-product attention.
3. Classification Head: The attended features are passed through a two-layer MLP
with ReLU activation and dropout (p=0.5) for final classification. The output
layer uses softmax activation.
The entire model is trained end-to-end using cross-entropy loss with the Adam optimizer.
"""

# Figure caption paired with EXAMPLE_METHODOLOGY in the examples.
EXAMPLE_CAPTION = (
    "Architecture of our proposed attention-based image classification model"
)

# Hand-written stand-ins for reference diagrams. A real setup would load
# these records from a database. Each record has the same four keys:
# id, domain, diagram_type and description.
EXAMPLE_REFERENCE_SET = [
    dict(
        id="ref_001",
        domain="Computer Vision",
        diagram_type="Architecture Diagram",
        description=(
            "CNN architecture with attention modules showing feature "
            "extraction, attention layers, and classification head"
        ),
    ),
    dict(
        id="ref_002",
        domain="Computer Vision",
        diagram_type="Pipeline Diagram",
        description=(
            "Image processing pipeline from input through multiple stages "
            "to output"
        ),
    ),
    dict(
        id="ref_003",
        domain="Natural Language Processing",
        diagram_type="Architecture Diagram",
        description="Transformer architecture with self-attention mechanism",
    ),
]
def example_basic_usage():
    """Example 1: call the one-shot API with its default settings.

    Feeds ``EXAMPLE_METHODOLOGY`` and ``EXAMPLE_CAPTION`` to
    ``generate_illustration``, writing under ``examples/basic_example``. Then
    prints the ``final_image_path`` and ``iterations`` entries of the returned
    mapping.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 1: Basic Usage")
    print(rule + "\n")

    outcome = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/basic_example",
    )

    print(f"\nGenerated image: {outcome['final_image_path']}")
    print(f"Iterations: {outcome['iterations']}")
def example_with_references():
    """Example 2: generate an illustration guided by reference examples.

    Passes ``EXAMPLE_REFERENCE_SET`` as ``reference_set`` together with the
    example methodology and caption, writing under
    ``examples/with_references``. Prints the resulting image path.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 2: With Reference Examples")
    print(rule + "\n")

    outcome = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        reference_set=EXAMPLE_REFERENCE_SET,
        output_path="examples/with_references",
    )
    print(f"\nGenerated image: {outcome['final_image_path']}")
def example_ablation_study():
    """Example 3: ablation study that disables one pipeline stage per run.

    Calls ``generate_illustration`` once for each variant below on the same
    methodology and caption:

    * ``skip_styling=True``, which runs without the Stylist agent.
    * ``skip_refinement=True``, which runs without iterative refinement.

    Each variant writes to its own directory under ``examples/``. Its
    resulting image path is printed so the variants can be compared. The
    original version computed these results but discarded them.
    """
    print("\n" + "="*80)
    print("EXAMPLE 3: Ablation Study")
    print("="*80 + "\n")

    # (heading, output directory, extra generate_illustration kwargs).
    # Each variant turns off exactly one component.
    variants = [
        ("Without Stylist Agent", "examples/ablation_no_style",
         {"skip_styling": True}),
        ("Without Iterative Refinement", "examples/ablation_no_refinement",
         {"skip_refinement": True}),
    ]

    for heading, output_path, overrides in variants:
        print(f"\n--- {heading} ---")
        result = generate_illustration(
            methodology_text=EXAMPLE_METHODOLOGY,
            caption=EXAMPLE_CAPTION,
            output_path=output_path,
            **overrides,
        )
        # Report each run like the other examples do, instead of dropping it.
        print(f"\nGenerated image: {result['final_image_path']}")
def example_statistical_plot():
    """Example 4: ask PaperBanana (plot mode) for a statistical line plot.

    Describes a three-model accuracy-vs-epoch comparison in natural language.
    Supplies illustrative numbers through ``data`` and calls
    ``PaperBanana(mode="plot").generate``, writing under
    ``examples/accuracy_plot``. Then prints the reported
    ``final_image_path`` and reminds the user to run the generated Python
    file.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 4: Statistical Plot Generation")
    print(rule + "\n")

    plot_description = """
    Create a line plot comparing accuracy across training epochs for three models:
    - Baseline CNN (blue line)
    - Our method without attention (orange line)
    - Our full method (green line)
    X-axis: Training Epochs (0-100)
    Y-axis: Validation Accuracy (%)
    The baseline should plateau around 85%, method without attention around 88%,
    and full method should reach 92%.
    """

    # Illustrative numbers only; a real run would use logged experiment
    # results. Each series has one value per epoch tick (0, 10, ..., 100).
    epochs = list(range(0, 101, 10))
    baseline = [60, 70, 75, 78, 80, 82, 83, 84, 85, 85, 85]
    no_attention = [65, 75, 80, 83, 85, 86, 87, 87.5, 88, 88, 88]
    full_method = [70, 80, 85, 87, 89, 90, 91, 91.5, 92, 92, 92]
    plot_data = {
        "epochs": epochs,
        "baseline": baseline,
        "no_attention": no_attention,
        "full_method": full_method,
    }

    plotter = PaperBanana(mode="plot")
    outcome = plotter.generate(
        methodology_text=plot_description,
        caption="Comparison of validation accuracy across training epochs",
        output_path="examples/accuracy_plot",
        data=plot_data,
    )

    print(f"\nGenerated plot code: {outcome['final_image_path']}")
    print("Run the generated Python file to create the plot image.")
def example_with_neurips_references():
    """Example 5b: generate using the MinerU-parsed NeurIPS reference set.

    Loads references through the local ``load_reference_set`` helper. If it
    yields nothing (falsy), a hint about the expected data file is printed
    and no generation happens. Otherwise the references are passed to
    ``generate_illustration``, writing under ``examples/neurips_refs``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 5b: With NeurIPS 2025 Reference Set (from MinerU)")
    print(rule + "\n")

    # Local import: the helper module is only needed for this example.
    from load_reference_set import load_reference_set

    references = load_reference_set()
    if references:
        outcome = generate_illustration(
            methodology_text=EXAMPLE_METHODOLOGY,
            caption=EXAMPLE_CAPTION,
            reference_set=references,
            output_path="examples/neurips_refs",
        )
        print(f"\nGenerated image: {outcome['final_image_path']}")
    else:
        print("No reference set found. Ensure data/spotlight_reference_set.json exists.")
def example_full_pipeline():
    """Example 6: configured PaperBanana instance with saved history.

    Builds a diagram-mode ``PaperBanana`` seeded with
    ``EXAMPLE_REFERENCE_SET`` and ``max_iterations=3``, generates under
    ``examples/full_pipeline``, and saves the generation history to
    ``examples/generation_history.json``. Then prints the final image path
    and the number of description versions and critiques recorded in the
    result's ``history``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 6: Full Pipeline with History")
    print(rule + "\n")

    banana = PaperBanana(
        reference_set=EXAMPLE_REFERENCE_SET,
        mode="diagram",
        max_iterations=3,
    )
    outcome = banana.generate(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/full_pipeline",
    )

    # Persist the run's history so it can be inspected afterwards.
    banana.save_history("examples/generation_history.json")

    print(f"\nFinal image: {outcome['final_image_path']}")
    history = outcome["history"]
    print(f"Description versions: {len(history['descriptions'])}")
    print(f"Critiques performed: {len(history['critiques'])}")
def main():
    """Run the enabled PaperBanana examples.

    Prints an introduction, then verifies that ``GEMINI_API_KEY`` is set. If
    it is not, an error and a hint are printed and the function returns
    before touching the filesystem. Otherwise it creates the ``examples/``
    output directory and runs each example listed in ``enabled_examples``.
    Comment entries in or out to choose which ones run.

    Each example runs in isolation. An exception is reported with its
    traceback, and the remaining examples still run. The closing banner says
    whether every example succeeded, so failures are not presented as a
    successful run.
    """
    print("\n" + "="*80)
    print("PaperBanana Examples")
    print("="*80)
    print("\nThese examples demonstrate various features of the PaperBanana framework.")
    print("Make sure you have set the GEMINI_API_KEY environment variable.\n")

    # Check the key before creating anything, so a misconfigured run does
    # not leave an empty output directory behind.
    if not os.environ.get("GEMINI_API_KEY"):
        print("ERROR: GEMINI_API_KEY environment variable not set!")
        print("Please set it with: export GEMINI_API_KEY='your-api-key'")
        return

    os.makedirs("examples", exist_ok=True)

    # Examples to run (comment out any you don't want to run).
    enabled_examples = [
        example_basic_usage,                  # Example 1: Basic usage
        # example_with_references,            # Example 2: With references
        # example_ablation_study,             # Example 3: Ablation study
        # example_statistical_plot,           # Example 4: Statistical plots
        # example_with_neurips_references,    # Example 5b: NeurIPS MinerU refs
        # example_full_pipeline,              # Example 6: Full pipeline
    ]

    failures = 0
    for run_example in enabled_examples:
        try:
            run_example()
        except Exception as e:  # top-level boundary: report and keep going
            failures += 1
            print(f"\nError during execution of {run_example.__name__}: {e}")
            import traceback
            traceback.print_exc()

    print("\n" + "="*80)
    if failures:
        print(f"Examples finished with {failures} error(s).")
    else:
        print("Examples Complete!")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()