File size: 7,910 Bytes
572d3da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
Example usage of PaperBanana framework.

This script demonstrates how to use PaperBanana to generate academic illustrations.
"""
import os
from paperbanana import PaperBanana, generate_illustration

# Example methodology from a hypothetical paper
EXAMPLE_METHODOLOGY = """
Our proposed method consists of three main stages:

1. Feature Extraction: We use a pretrained ResNet-50 backbone to extract visual features 
   from input images. The features are pooled using adaptive average pooling to obtain 
   a fixed-size representation.

2. Attention Mechanism: We apply multi-head self-attention to capture long-range 
   dependencies between different spatial regions. The attention module has 8 heads 
   and uses scaled dot-product attention.

3. Classification Head: The attended features are passed through a two-layer MLP 
   with ReLU activation and dropout (p=0.5) for final classification. The output 
   layer uses softmax activation.

The entire model is trained end-to-end using cross-entropy loss with the Adam optimizer.
"""

EXAMPLE_CAPTION = "Architecture of our proposed attention-based image classification model"

# Example reference set (normally would be loaded from a database)
EXAMPLE_REFERENCE_SET = [
    {
        'id': 'ref_001',
        'domain': 'Computer Vision',
        'diagram_type': 'Architecture Diagram',
        'description': 'CNN architecture with attention modules showing feature extraction, attention layers, and classification head'
    },
    {
        'id': 'ref_002', 
        'domain': 'Computer Vision',
        'diagram_type': 'Pipeline Diagram',
        'description': 'Image processing pipeline from input through multiple stages to output'
    },
    {
        'id': 'ref_003',
        'domain': 'Natural Language Processing',
        'diagram_type': 'Architecture Diagram',
        'description': 'Transformer architecture with self-attention mechanism'
    },
]


def example_basic_usage():
    """Example 1: Basic usage with default settings."""
    print("\n" + "="*80)
    print("EXAMPLE 1: Basic Usage")
    print("="*80 + "\n")
    
    result = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/basic_example"
    )
    
    print(f"\nGenerated image: {result['final_image_path']}")
    print(f"Iterations: {result['iterations']}")


def example_with_references():
    """Example 2: Using reference examples."""
    print("\n" + "="*80)
    print("EXAMPLE 2: With Reference Examples")
    print("="*80 + "\n")
    
    result = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        reference_set=EXAMPLE_REFERENCE_SET,
        output_path="examples/with_references"
    )
    
    print(f"\nGenerated image: {result['final_image_path']}")


def example_ablation_study():
    """Example 3: Ablation study - testing without certain components."""
    print("\n" + "="*80)
    print("EXAMPLE 3: Ablation Study")
    print("="*80 + "\n")
    
    # Without styling
    print("\n--- Without Stylist Agent ---")
    result1 = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_style",
        skip_styling=True
    )
    
    # Without refinement
    print("\n--- Without Iterative Refinement ---")
    result2 = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_refinement",
        skip_refinement=True
    )


def example_statistical_plot():
    """Example 4: Generating statistical plots."""
    print("\n" + "="*80)
    print("EXAMPLE 4: Statistical Plot Generation")
    print("="*80 + "\n")
    
    plot_description = """
    Create a line plot comparing accuracy across training epochs for three models:
    - Baseline CNN (blue line)
    - Our method without attention (orange line)  
    - Our full method (green line)
    
    X-axis: Training Epochs (0-100)
    Y-axis: Validation Accuracy (%)
    
    The baseline should plateau around 85%, method without attention around 88%,
    and full method should reach 92%.
    """
    
    # Example data (normally would come from actual experiments)
    plot_data = {
        'epochs': list(range(0, 101, 10)),
        'baseline': [60, 70, 75, 78, 80, 82, 83, 84, 85, 85, 85],
        'no_attention': [65, 75, 80, 83, 85, 86, 87, 87.5, 88, 88, 88],
        'full_method': [70, 80, 85, 87, 89, 90, 91, 91.5, 92, 92, 92]
    }
    
    pb = PaperBanana(mode="plot")
    result = pb.generate(
        methodology_text=plot_description,
        caption="Comparison of validation accuracy across training epochs",
        output_path="examples/accuracy_plot",
        data=plot_data
    )
    
    print(f"\nGenerated plot code: {result['final_image_path']}")
    print("Run the generated Python file to create the plot image.")


def example_with_neurips_references():
    """Example 5b: Using MinerU-parsed NeurIPS reference set."""
    print("\n" + "="*80)
    print("EXAMPLE 5b: With NeurIPS 2025 Reference Set (from MinerU)")
    print("="*80 + "\n")

    from load_reference_set import load_reference_set

    ref_set = load_reference_set()
    if not ref_set:
        print("No reference set found. Ensure data/spotlight_reference_set.json exists.")
        return

    result = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        reference_set=ref_set,
        output_path="examples/neurips_refs"
    )
    print(f"\nGenerated image: {result['final_image_path']}")


def example_full_pipeline():
    """Example 6: Full pipeline with all features and history saving."""
    print("\n" + "="*80)
    print("EXAMPLE 6: Full Pipeline with History")
    print("="*80 + "\n")
    
    pb = PaperBanana(
        reference_set=EXAMPLE_REFERENCE_SET,
        mode="diagram",
        max_iterations=3
    )
    
    result = pb.generate(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/full_pipeline"
    )
    
    # Save generation history for analysis
    pb.save_history("examples/generation_history.json")
    
    print(f"\nFinal image: {result['final_image_path']}")
    print(f"Description versions: {len(result['history']['descriptions'])}")
    print(f"Critiques performed: {len(result['history']['critiques'])}")


def main():
    """Run all examples."""
    # Create examples directory
    os.makedirs("examples", exist_ok=True)
    
    print("\n" + "="*80)
    print("PaperBanana Examples")
    print("="*80)
    print("\nThese examples demonstrate various features of the PaperBanana framework.")
    print("Make sure you have set the GEMINI_API_KEY environment variable.\n")
    
    # Check for API key
    if not os.environ.get("GEMINI_API_KEY"):
        print("ERROR: GEMINI_API_KEY environment variable not set!")
        print("Please set it with: export GEMINI_API_KEY='your-api-key'")
        return
    
    # Run examples (comment out any you don't want to run)
    try:
        # Example 1: Basic usage
        example_basic_usage()
        
        # Example 2: With references
        # example_with_references()
        
        # Example 3: Ablation study
        # example_ablation_study()
        
        # Example 4: Statistical plots
        # example_statistical_plot()
        
        # Example 5b: With NeurIPS MinerU references
        # example_with_neurips_references()
        
        # Example 6: Full pipeline
        # example_full_pipeline()
        
    except Exception as e:
        print(f"\nError during execution: {e}")
        import traceback
        traceback.print_exc()
    
    print("\n" + "="*80)
    print("Examples Complete!")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()