Spaces:
Running
Running
File size: 7,910 Bytes
572d3da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
"""
Example usage of PaperBanana framework.
This script demonstrates how to use PaperBanana to generate academic illustrations.
"""
import os
from paperbanana import PaperBanana, generate_illustration
# Example methodology from a hypothetical paper.
# Assembled from adjacent string literals (implicit concatenation). Each piece
# ends with an explicit "\n", and the leading "\n" keeps the blank first line,
# so the text reads one sentence-fragment per line.
EXAMPLE_METHODOLOGY = (
    "\n"
    "Our proposed method consists of three main stages:\n"
    "1. Feature Extraction: We use a pretrained ResNet-50 backbone to extract visual features\n"
    "from input images. The features are pooled using adaptive average pooling to obtain\n"
    "a fixed-size representation.\n"
    "2. Attention Mechanism: We apply multi-head self-attention to capture long-range\n"
    "dependencies between different spatial regions. The attention module has 8 heads\n"
    "and uses scaled dot-product attention.\n"
    "3. Classification Head: The attended features are passed through a two-layer MLP\n"
    "with ReLU activation and dropout (p=0.5) for final classification. The output\n"
    "layer uses softmax activation.\n"
    "The entire model is trained end-to-end using cross-entropy loss with the Adam optimizer.\n"
)
EXAMPLE_CAPTION = "Architecture of our proposed attention-based image classification model"

# Example reference set (normally would be loaded from a database).
# Every entry carries the same four keys, in this order:
# id, domain, diagram_type, description.
EXAMPLE_REFERENCE_SET = [
    dict(
        id="ref_001",
        domain="Computer Vision",
        diagram_type="Architecture Diagram",
        description=(
            "CNN architecture with attention modules showing feature extraction, "
            "attention layers, and classification head"
        ),
    ),
    dict(
        id="ref_002",
        domain="Computer Vision",
        diagram_type="Pipeline Diagram",
        description="Image processing pipeline from input through multiple stages to output",
    ),
    dict(
        id="ref_003",
        domain="Natural Language Processing",
        diagram_type="Architecture Diagram",
        description="Transformer architecture with self-attention mechanism",
    ),
]
def example_basic_usage():
    """Example 1: run ``generate_illustration`` once with its default settings.

    Passes the module's example methodology and caption with
    ``examples/basic_example`` as the output path. It then prints the
    ``final_image_path`` and ``iterations`` entries of the returned result.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("EXAMPLE 1: Basic Usage")
    print(banner + "\n")

    outcome = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/basic_example",
    )

    # Look up and print each field in turn, so a missing key surfaces
    # only after the earlier line has been printed.
    print(f"\nGenerated image: {outcome['final_image_path']}")
    print(f"Iterations: {outcome['iterations']}")
def example_with_references():
    """Example 2: run ``generate_illustration`` with in-file reference examples.

    Same inputs as Example 1, plus ``EXAMPLE_REFERENCE_SET`` passed as
    ``reference_set``. Output goes to ``examples/with_references``, and the
    ``final_image_path`` of the result is printed.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("EXAMPLE 2: With Reference Examples")
    print(separator + "\n")

    request = {
        "methodology_text": EXAMPLE_METHODOLOGY,
        "caption": EXAMPLE_CAPTION,
        "reference_set": EXAMPLE_REFERENCE_SET,
        "output_path": "examples/with_references",
    }
    outcome = generate_illustration(**request)
    print(f"\nGenerated image: {outcome['final_image_path']}")
def example_ablation_study():
    """Example 3: ablation study that disables one pipeline component per run.

    Calls ``generate_illustration`` twice with the example methodology/caption:

    * ``skip_styling=True``, output path ``examples/ablation_no_style``
    * ``skip_refinement=True``, output path ``examples/ablation_no_refinement``

    The ``final_image_path`` of each run is printed, as in the other examples,
    so the two ablation outputs can be located and compared.
    """
    print("\n" + "="*80)
    print("EXAMPLE 3: Ablation Study")
    print("="*80 + "\n")

    # Without styling
    print("\n--- Without Stylist Agent ---")
    result1 = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_style",
        skip_styling=True,
    )
    # Previously this result was discarded; report it like every other example.
    print(f"\nGenerated image: {result1['final_image_path']}")

    # Without refinement
    print("\n--- Without Iterative Refinement ---")
    result2 = generate_illustration(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/ablation_no_refinement",
        skip_refinement=True,
    )
    print(f"\nGenerated image: {result2['final_image_path']}")
def example_statistical_plot():
    """Example 4: use a ``PaperBanana`` instance in ``"plot"`` mode.

    Describes a three-model accuracy-vs-epoch line plot and supplies toy data
    series for it through ``data``. It then prints the path stored under
    ``final_image_path``; per the printed hint, that path is a generated
    Python file to run, not an image.
    """
    print("\n" + "=" * 80)
    print("EXAMPLE 4: Statistical Plot Generation")
    print("=" * 80 + "\n")

    plot_description = (
        "\n"
        "Create a line plot comparing accuracy across training epochs for three models:\n"
        "- Baseline CNN (blue line)\n"
        "- Our method without attention (orange line)\n"
        "- Our full method (green line)\n"
        "X-axis: Training Epochs (0-100)\n"
        "Y-axis: Validation Accuracy (%)\n"
        "The baseline should plateau around 85%, method without attention around 88%,\n"
        "and full method should reach 92%.\n"
    )

    # Toy series (normally these would come from actual experiments). Each
    # series holds 11 values, one per epoch tick 0, 10, ..., 100.
    plot_data = dict(
        epochs=[*range(0, 101, 10)],
        baseline=[60, 70, 75, 78, 80, 82, 83, 84, 85, 85, 85],
        no_attention=[65, 75, 80, 83, 85, 86, 87, 87.5, 88, 88, 88],
        full_method=[70, 80, 85, 87, 89, 90, 91, 91.5, 92, 92, 92],
    )

    plotter = PaperBanana(mode="plot")
    outcome = plotter.generate(
        methodology_text=plot_description,
        caption="Comparison of validation accuracy across training epochs",
        output_path="examples/accuracy_plot",
        data=plot_data,
    )
    print(f"\nGenerated plot code: {outcome['final_image_path']}")
    print("Run the generated Python file to create the plot image.")
def example_with_neurips_references():
    """Example 5b: use the MinerU-parsed NeurIPS reference set.

    Loads references via ``load_reference_set()``. If that returns an empty
    or falsy value, it prints a hint about the expected JSON file and does
    nothing else. Otherwise it passes the references to
    ``generate_illustration`` (output path ``examples/neurips_refs``) and
    prints the resulting ``final_image_path``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EXAMPLE 5b: With NeurIPS 2025 Reference Set (from MinerU)")
    print(rule + "\n")

    # Imported lazily so the other examples work without this helper module.
    from load_reference_set import load_reference_set

    references = load_reference_set()
    if references:
        outcome = generate_illustration(
            methodology_text=EXAMPLE_METHODOLOGY,
            caption=EXAMPLE_CAPTION,
            reference_set=references,
            output_path="examples/neurips_refs",
        )
        print(f"\nGenerated image: {outcome['final_image_path']}")
    else:
        print("No reference set found. Ensure data/spotlight_reference_set.json exists.")
def example_full_pipeline():
    """Example 6: configure a ``PaperBanana`` instance and run it end to end.

    Builds the instance in ``"diagram"`` mode with the in-file reference set
    and ``max_iterations=3``, then generates into ``examples/full_pipeline``.
    It calls ``save_history("examples/generation_history.json")`` and prints
    the final image path plus the number of entries under
    ``result['history']['descriptions']`` and ``['critiques']``.
    """
    line = "=" * 80
    print("\n" + line)
    print("EXAMPLE 6: Full Pipeline with History")
    print(line + "\n")

    pipeline = PaperBanana(
        reference_set=EXAMPLE_REFERENCE_SET,
        mode="diagram",
        max_iterations=3,
    )
    outcome = pipeline.generate(
        methodology_text=EXAMPLE_METHODOLOGY,
        caption=EXAMPLE_CAPTION,
        output_path="examples/full_pipeline",
    )

    # Persist generation history for later analysis.
    pipeline.save_history("examples/generation_history.json")

    print(f"\nFinal image: {outcome['final_image_path']}")
    history = outcome['history']
    print(f"Description versions: {len(history['descriptions'])}")
    print(f"Critiques performed: {len(history['critiques'])}")
def main():
    """Run the enabled PaperBanana examples.

    Steps:

    * Create the ``examples`` output directory.
    * Print an intro banner.
    * Return early with an error message if ``GEMINI_API_KEY`` is unset or empty.
    * Otherwise run the examples that are not commented out below.

    An exception from any example is caught, reported with its message and
    a traceback, and does not propagate. The closing banner says
    "Examples Complete!" only when every enabled example finished without
    raising; otherwise it says "Examples Failed!".
    """
    import traceback

    # Create examples directory
    os.makedirs("examples", exist_ok=True)
    print("\n" + "="*80)
    print("PaperBanana Examples")
    print("="*80)
    print("\nThese examples demonstrate various features of the PaperBanana framework.")
    print("Make sure you have set the GEMINI_API_KEY environment variable.\n")

    # Check for API key
    if not os.environ.get("GEMINI_API_KEY"):
        print("ERROR: GEMINI_API_KEY environment variable not set!")
        print("Please set it with: export GEMINI_API_KEY='your-api-key'")
        return

    # Run examples (comment out any you don't want to run)
    try:
        # Example 1: Basic usage
        example_basic_usage()
        # Example 2: With references
        # example_with_references()
        # Example 3: Ablation study
        # example_ablation_study()
        # Example 4: Statistical plots
        # example_statistical_plot()
        # Example 5b: With NeurIPS MinerU references
        # example_with_neurips_references()
        # Example 6: Full pipeline
        # example_full_pipeline()
    except Exception as e:
        # Top-level boundary of a demo script: report the failure instead of
        # crashing, but do not announce success afterwards.
        print(f"\nError during execution: {e}")
        traceback.print_exc()
        summary = "Examples Failed!"
    else:
        summary = "Examples Complete!"

    print("\n" + "="*80)
    print(summary)
    print("="*80 + "\n")


if __name__ == "__main__":
    main()
|