# PaperBanana / examples.py
# Author: Samarth0710
# Commit 572d3da (verified): "Deploy PaperBanana app"
"""
Example usage of PaperBanana framework.
This script demonstrates how to use PaperBanana to generate academic illustrations.
"""
import os
from paperbanana import PaperBanana, generate_illustration
# Example methodology from a hypothetical paper. Passed as `methodology_text`
# by the diagram examples below (basic, references, ablation, NeurIPS refs,
# full pipeline). The string content is sent verbatim, so edit it with care.
EXAMPLE_METHODOLOGY = """
Our proposed method consists of three main stages:
1. Feature Extraction: We use a pretrained ResNet-50 backbone to extract visual features
from input images. The features are pooled using adaptive average pooling to obtain
a fixed-size representation.
2. Attention Mechanism: We apply multi-head self-attention to capture long-range
dependencies between different spatial regions. The attention module has 8 heads
and uses scaled dot-product attention.
3. Classification Head: The attended features are passed through a two-layer MLP
with ReLU activation and dropout (p=0.5) for final classification. The output
layer uses softmax activation.
The entire model is trained end-to-end using cross-entropy loss with the Adam optimizer.
"""
# Figure caption passed as `caption` together with EXAMPLE_METHODOLOGY.
EXAMPLE_CAPTION = "Architecture of our proposed attention-based image classification model"
# Small hand-written reference set used by the "with references" and
# "full pipeline" examples. In real use this would be loaded from a database
# (see example_with_neurips_references for a file-backed alternative).
# Each entry carries an id, a research domain, a diagram type and a
# free-text description of the reference figure.
EXAMPLE_REFERENCE_SET = [
    {
        "id": "ref_001",
        "domain": "Computer Vision",
        "diagram_type": "Architecture Diagram",
        "description": (
            "CNN architecture with attention modules showing feature extraction, "
            "attention layers, and classification head"
        ),
    },
    {
        "id": "ref_002",
        "domain": "Computer Vision",
        "diagram_type": "Pipeline Diagram",
        "description": "Image processing pipeline from input through multiple stages to output",
    },
    {
        "id": "ref_003",
        "domain": "Natural Language Processing",
        "diagram_type": "Architecture Diagram",
        "description": "Transformer architecture with self-attention mechanism",
    },
]
def example_basic_usage():
    """Example 1: Generate a diagram with every setting left at its default.

    Calls ``generate_illustration`` with only the methodology text, caption
    and output path, then prints the final image path and the reported
    iteration count from the returned result.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("EXAMPLE 1: Basic Usage")
    print(bar + "\n")

    request = {
        "methodology_text": EXAMPLE_METHODOLOGY,
        "caption": EXAMPLE_CAPTION,
        "output_path": "examples/basic_example",
    }
    result = generate_illustration(**request)

    image_path = result['final_image_path']
    print(f"\nGenerated image: {image_path}")
    iteration_count = result['iterations']
    print(f"Iterations: {iteration_count}")
def example_with_references():
    """Example 2: Generate a diagram guided by the in-file reference examples.

    Same as the basic example, but also passes ``EXAMPLE_REFERENCE_SET`` as
    ``reference_set``; prints the final image path from the result.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("EXAMPLE 2: With Reference Examples")
    print(bar + "\n")

    request = {
        "methodology_text": EXAMPLE_METHODOLOGY,
        "caption": EXAMPLE_CAPTION,
        "reference_set": EXAMPLE_REFERENCE_SET,
        "output_path": "examples/with_references",
    }
    result = generate_illustration(**request)

    image_path = result['final_image_path']
    print(f"\nGenerated image: {image_path}")
def example_ablation_study():
    """Example 3: Ablation study - generate with individual stages disabled.

    Runs two generations of the same methodology/caption:

    * ``skip_styling=True``   -> output under ``examples/ablation_no_style``
    * ``skip_refinement=True`` -> output under ``examples/ablation_no_refinement``

    and prints the final image path of each variant so the outputs can be
    compared side by side.
    """
    print("\n" + "="*80)
    print("EXAMPLE 3: Ablation Study")
    print("="*80 + "\n")

    # Without styling
    print("\n--- Without Stylist Agent ---")
    result_no_style = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_style",
        skip_styling=True,
    )
    # Report where this variant ended up (previously the result was discarded).
    print(f"\nGenerated image: {result_no_style['final_image_path']}")

    # Without refinement
    print("\n--- Without Iterative Refinement ---")
    result_no_refinement = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_refinement",
        skip_refinement=True,
    )
    print(f"\nGenerated image: {result_no_refinement['final_image_path']}")
def example_statistical_plot():
    """Example 4: Generate a statistical (line) plot instead of a diagram.

    Builds a textual plot specification plus a small dict of synthetic
    accuracy curves, and hands both to a ``PaperBanana`` instance created
    with ``mode="plot"``. The printed path is labelled as generated plot
    code; the follow-up message tells the user to run that Python file to
    produce the image.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("EXAMPLE 4: Statistical Plot Generation")
    print(bar + "\n")

    # Natural-language specification of the desired figure.
    plot_description = (
        "\n"
        "Create a line plot comparing accuracy across training epochs for three models:\n"
        "- Baseline CNN (blue line)\n"
        "- Our method without attention (orange line)\n"
        "- Our full method (green line)\n"
        "X-axis: Training Epochs (0-100)\n"
        "Y-axis: Validation Accuracy (%)\n"
        "The baseline should plateau around 85%, method without attention around 88%,\n"
        "and full method should reach 92%.\n"
    )

    # Synthetic curves (normally would come from actual experiments):
    # 11 epoch checkpoints (0, 10, ..., 100) and one accuracy value per checkpoint.
    epochs = list(range(0, 101, 10))
    plot_data = {
        'epochs': epochs,
        'baseline': [60, 70, 75, 78, 80, 82, 83, 84, 85, 85, 85],
        'no_attention': [65, 75, 80, 83, 85, 86, 87, 87.5, 88, 88, 88],
        'full_method': [70, 80, 85, 87, 89, 90, 91, 91.5, 92, 92, 92],
    }

    plotter = PaperBanana(mode="plot")
    result = plotter.generate(
        methodology_text=plot_description,
        caption="Comparison of validation accuracy across training epochs",
        output_path="examples/accuracy_plot",
        data=plot_data,
    )

    code_path = result['final_image_path']
    print(f"\nGenerated plot code: {code_path}")
    print("Run the generated Python file to create the plot image.")
def example_with_neurips_references():
    """Example 5b: Generate a diagram guided by the MinerU-parsed NeurIPS reference set.

    The reference set is obtained from ``load_reference_set()`` in the local
    ``load_reference_set`` module, imported lazily so the other examples do
    not depend on it. When the loader returns nothing (empty/falsy), a hint
    is printed and no generation is attempted.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("EXAMPLE 5b: With NeurIPS 2025 Reference Set (from MinerU)")
    print(bar + "\n")

    # Imported here (after the banner) so a missing helper module only
    # affects this example.
    from load_reference_set import load_reference_set

    references = load_reference_set()
    if references:
        result = generate_illustration(
            methodology_text=EXAMPLE_METHODOLOGY,
            caption=EXAMPLE_CAPTION,
            reference_set=references,
            output_path="examples/neurips_refs",
        )
        image_path = result['final_image_path']
        print(f"\nGenerated image: {image_path}")
    else:
        print("No reference set found. Ensure data/spotlight_reference_set.json exists.")
def example_full_pipeline():
    """Example 6: Full pipeline with reference examples and saved history.

    Builds a ``PaperBanana`` instance in ``"diagram"`` mode with the in-file
    reference set and ``max_iterations=3``, generates the figure, writes the
    generation history to ``examples/generation_history.json`` and prints
    the final image path plus the number of description versions and
    critiques recorded in the result's history.
    """
    bar = "=" * 80
    print("\n" + bar)
    print("EXAMPLE 6: Full Pipeline with History")
    print(bar + "\n")

    generator = PaperBanana(
        reference_set=EXAMPLE_REFERENCE_SET,
        mode="diagram",
        max_iterations=3,
    )
    result = generator.generate(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/full_pipeline",
    )

    # Persist the generation history for later analysis.
    generator.save_history("examples/generation_history.json")

    image_path = result['final_image_path']
    print(f"\nFinal image: {image_path}")
    history = result['history']
    print(f"Description versions: {len(history['descriptions'])}")
    print(f"Critiques performed: {len(history['critiques'])}")
def main():
    """Run the enabled PaperBanana examples.

    Creates the ``examples`` output directory, verifies that the
    ``GEMINI_API_KEY`` environment variable is set, then runs each enabled
    example in turn. A failure in one example is reported with its traceback
    and does not prevent the remaining enabled examples from running.

    Returns:
        0 if every enabled example completed, 1 if the API key is missing or
        any example raised. (Callers that ignore the return value behave as
        before.)
    """
    import traceback

    # Create examples directory (all example output paths live under it)
    os.makedirs("examples", exist_ok=True)

    print("\n" + "="*80)
    print("PaperBanana Examples")
    print("="*80)
    print("\nThese examples demonstrate various features of the PaperBanana framework.")
    print("Make sure you have set the GEMINI_API_KEY environment variable.\n")

    # Check for API key before attempting any example
    if not os.environ.get("GEMINI_API_KEY"):
        print("ERROR: GEMINI_API_KEY environment variable not set!")
        print("Please set it with: export GEMINI_API_KEY='your-api-key'")
        return 1

    # Examples to run (comment out any you don't want to run)
    examples = [
        example_basic_usage,                  # Example 1: Basic usage
        # example_with_references,            # Example 2: With references
        # example_ablation_study,             # Example 3: Ablation study
        # example_statistical_plot,           # Example 4: Statistical plots
        # example_with_neurips_references,    # Example 5b: With NeurIPS MinerU references
        # example_full_pipeline,              # Example 6: Full pipeline
    ]

    failed = []
    for example in examples:
        try:
            example()
        except Exception as e:
            # Top-level boundary: report and keep going so one failing
            # example does not hide the results of the others.
            print(f"\nError during execution: {e}")
            traceback.print_exc()
            failed.append(example.__name__)

    print("\n" + "="*80)
    if failed:
        # Previously "Examples Complete!" was printed even after a failure.
        print(f"Examples finished with errors in: {', '.join(failed)}")
    else:
        print("Examples Complete!")
    print("="*80 + "\n")
    return 1 if failed else 0
if __name__ == "__main__":
    # Propagate main()'s status as the process exit code; a None return
    # (success) exits with status 0.
    raise SystemExit(main())