PaperBanana / main.py
dwzhu
Initial deployment: Gradio app + PaperBananaBench data
587f33e
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main script to launch PaperVizAgent
"""
import asyncio
import json
import argparse
from pathlib import Path
import aiofiles
import numpy as np
from agents.vanilla_agent import VanillaAgent
from agents.planner_agent import PlannerAgent
from agents.visualizer_agent import VisualizerAgent
from agents.stylist_agent import StylistAgent
from agents.critic_agent import CriticAgent
from agents.retriever_agent import RetrieverAgent
from agents.polish_agent import PolishAgent
from utils import config, paperviz_processor
async def main():
"""Main function"""
# add command line args
parser = argparse.ArgumentParser(description="PaperVizAgent processing script")
parser.add_argument(
"--dataset_name",
type=str,
default="PaperBananaBench",
help="name of the dataset to use (default: PaperBananaBench)",
)
parser.add_argument(
"--task_name",
type=str,
default="diagram",
choices=["diagram", "plot"],
help="task type: diagram or plot (default: diagram)",
)
parser.add_argument(
"--split_name",
type=str,
default="test",
help="split of the dataset to use (default: test)",
)
parser.add_argument(
"--exp_mode",
type=str,
default="dev",
help="name of the experiment to use (default: dev)",
)
parser.add_argument(
"--retrieval_setting",
type=str,
default="auto",
choices=["auto", "manual", "random", "none"],
help="retrieval setting for planner agent (default: auto)",
)
parser.add_argument(
"--max_critic_rounds",
type=int,
default=3,
help="maximum number of critic rounds (default: 3)",
)
parser.add_argument(
"--main_model_name",
type=str,
default="",
help="main model name to use (default: "")",
)
parser.add_argument(
"--image_gen_model_name",
type=str,
default="",
help="image generation model name to use (default: "")",
)
args = parser.parse_args()
exp_config = config.ExpConfig(
dataset_name=args.dataset_name,
task_name=args.task_name,
split_name=args.split_name,
exp_mode=args.exp_mode,
retrieval_setting=args.retrieval_setting,
max_critic_rounds=args.max_critic_rounds,
main_model_name=args.main_model_name,
image_gen_model_name=args.image_gen_model_name,
work_dir=Path(__file__).parent,
)
base_path = Path(__file__).parent / "data" / exp_config.dataset_name
input_filename = base_path / exp_config.task_name / f"{exp_config.split_name}.json"
output_filename = exp_config.result_dir / f"{exp_config.exp_name}.json"
print(f"Input file: {input_filename}", f"Output file: {output_filename}")
with open(input_filename, "r", encoding="utf-8") as f:
data_list = json.load(f)
# Create processor
processor = paperviz_processor.PaperVizProcessor(
exp_config=exp_config,
vanilla_agent=VanillaAgent(exp_config=exp_config),
planner_agent=PlannerAgent(exp_config=exp_config),
visualizer_agent=VisualizerAgent(exp_config=exp_config),
stylist_agent=StylistAgent(exp_config=exp_config),
critic_agent=CriticAgent(exp_config=exp_config),
retriever_agent=RetrieverAgent(exp_config=exp_config),
polish_agent=PolishAgent(exp_config=exp_config),
)
# Batch process documents
concurrent_num = 10
print(f"Using max concurrency: {concurrent_num}")
all_result_list = []
async def save_results_and_scores(current_results):
print(f"Incremental saving results (count: {len(current_results)}) to {output_filename}")
async with aiofiles.open(
output_filename, "w", encoding="utf-8", errors="surrogateescape"
) as f:
json_string = json.dumps(current_results, ensure_ascii=False, indent=4)
json_string = json_string.encode("utf-8", "ignore").decode("utf-8")
await f.write(json_string)
# Process samples incrementally
idx = 0
async for result_data in processor.process_queries_batch(
data_list, max_concurrent=concurrent_num
):
all_result_list.append(result_data)
idx += 1
if idx % 10 == 0:
await save_results_and_scores(all_result_list)
# Final save
await save_results_and_scores(all_result_list)
print("Processing completed.")
if __name__ == "__main__":
asyncio.run(main())