File size: 5,145 Bytes
587f33e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Main script to launch PaperVizAgent
"""

import asyncio
import json
import argparse
from pathlib import Path
import aiofiles
import numpy as np

from agents.vanilla_agent import VanillaAgent
from agents.planner_agent import PlannerAgent
from agents.visualizer_agent import VisualizerAgent
from agents.stylist_agent import StylistAgent
from agents.critic_agent import CriticAgent
from agents.retriever_agent import RetrieverAgent
from agents.polish_agent import PolishAgent

from utils import config, paperviz_processor


async def main():
    """Main function"""
    # add command line args
    parser = argparse.ArgumentParser(description="PaperVizAgent processing script")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="PaperBananaBench",
        help="name of the dataset to use (default: PaperBananaBench)",
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="diagram",
        choices=["diagram", "plot"],
        help="task type: diagram or plot (default: diagram)",
    )
    parser.add_argument(
        "--split_name",
        type=str,
        default="test",
        help="split of the dataset to use (default: test)",
    )
    parser.add_argument(
        "--exp_mode",
        type=str,
        default="dev",
        help="name of the experiment to use (default: dev)",
    )
    parser.add_argument(
        "--retrieval_setting",
        type=str,
        default="auto",
        choices=["auto", "manual", "random", "none"],
        help="retrieval setting for planner agent (default: auto)",
    )
    parser.add_argument(
        "--max_critic_rounds",
        type=int,
        default=3,
        help="maximum number of critic rounds (default: 3)",
    )
    parser.add_argument(
        "--main_model_name",
        type=str,
        default="",
        help="main model name to use (default: "")",
    )
    parser.add_argument(
        "--image_gen_model_name",
        type=str,
        default="",
        help="image generation model name to use (default: "")",
    )
    args = parser.parse_args()

    exp_config = config.ExpConfig(
        dataset_name=args.dataset_name,
        task_name=args.task_name,
        split_name=args.split_name,
        exp_mode=args.exp_mode,
        retrieval_setting=args.retrieval_setting,
        max_critic_rounds=args.max_critic_rounds,
        main_model_name=args.main_model_name,
        image_gen_model_name=args.image_gen_model_name,
        work_dir=Path(__file__).parent,
    )
    
    base_path = Path(__file__).parent / "data" / exp_config.dataset_name
    input_filename = base_path / exp_config.task_name / f"{exp_config.split_name}.json"
    output_filename = exp_config.result_dir / f"{exp_config.exp_name}.json"
    
    print(f"Input file: {input_filename}", f"Output file: {output_filename}")
    with open(input_filename, "r", encoding="utf-8") as f:
        data_list = json.load(f)

    # Create processor
    processor = paperviz_processor.PaperVizProcessor(
        exp_config=exp_config,
        vanilla_agent=VanillaAgent(exp_config=exp_config),
        planner_agent=PlannerAgent(exp_config=exp_config),
        visualizer_agent=VisualizerAgent(exp_config=exp_config),
        stylist_agent=StylistAgent(exp_config=exp_config),
        critic_agent=CriticAgent(exp_config=exp_config),
        retriever_agent=RetrieverAgent(exp_config=exp_config),
        polish_agent=PolishAgent(exp_config=exp_config),
    )

    # Batch process documents
    concurrent_num = 10
    print(f"Using max concurrency: {concurrent_num}")
    all_result_list = []

    async def save_results_and_scores(current_results):
        print(f"Incremental saving results (count: {len(current_results)}) to {output_filename}")
        async with aiofiles.open(
            output_filename, "w", encoding="utf-8", errors="surrogateescape"
        ) as f:
            json_string = json.dumps(current_results, ensure_ascii=False, indent=4)
            json_string = json_string.encode("utf-8", "ignore").decode("utf-8")
            await f.write(json_string)

    # Process samples incrementally
    idx = 0
    async for result_data in processor.process_queries_batch(
        data_list, max_concurrent=concurrent_num
    ):
        all_result_list.append(result_data)
        idx += 1
        if idx % 10 == 0:
            await save_results_and_scores(all_result_list)

    # Final save
    await save_results_and_scores(all_result_list)
    print("Processing completed.")


if __name__ == "__main__":
    asyncio.run(main())