Spaces:
Sleeping
Sleeping
| # Copyright 2026 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Processing pipeline of PaperVizAgent | |
| """ | |
| import asyncio | |
| from typing import List, Dict, Any, AsyncGenerator | |
| import numpy as np | |
| from tqdm.asyncio import tqdm | |
| from agents.vanilla_agent import VanillaAgent | |
| from agents.planner_agent import PlannerAgent | |
| from agents.visualizer_agent import VisualizerAgent | |
| from agents.stylist_agent import StylistAgent | |
| from agents.critic_agent import CriticAgent | |
| from agents.retriever_agent import RetrieverAgent | |
| from agents.polish_agent import PolishAgent | |
| from .config import ExpConfig | |
| from .eval_toolkits import get_score_for_image_referenced | |
class PaperVizProcessor:
    """Chain the PaperVizAgent agents for a single query or a batch of queries.

    The processor holds one instance of every agent and decides, from
    ``exp_config.exp_mode``, which of them run and in which order (see
    ``process_single_query``). Every agent is expected to expose an async
    ``process(data, ...)`` method that returns the (possibly updated) data
    dictionary.
    """

    # Modes for which ``process_queries_batch`` skips the shared retrieval step.
    _MODES_WITHOUT_SHARED_RETRIEVAL = ("vanilla", "dev_polish", "dev_retriever")
    # Keys copied from the shared retrieval result onto every batch item.
    _RETRIEVAL_KEYS = ("top10_references", "retrieved_examples")
    # Evaluation dimensions summarised in the batch progress bar.
    _EVAL_DIMS = ("faithfulness", "conciseness", "readability", "aesthetics", "overall")
    # Outcome labels that the progress bar reports together as a tie.
    _TIE_OUTCOMES = ("Tie", "Both are good", "Both are bad")

    def __init__(
        self,
        exp_config: "ExpConfig",
        vanilla_agent: "VanillaAgent",
        planner_agent: "PlannerAgent",
        visualizer_agent: "VisualizerAgent",
        stylist_agent: "StylistAgent",
        critic_agent: "CriticAgent",
        retriever_agent: "RetrieverAgent",
        polish_agent: "PolishAgent",
    ):
        """Store the experiment configuration and one instance of each agent."""
        self.exp_config = exp_config
        self.vanilla_agent = vanilla_agent
        self.planner_agent = planner_agent
        self.visualizer_agent = visualizer_agent
        self.stylist_agent = stylist_agent
        self.critic_agent = critic_agent
        self.retriever_agent = retriever_agent
        self.polish_agent = polish_agent

    async def _run_critic_iterations(
        self,
        data: Dict[str, Any],
        task_name: str,
        max_rounds: int = 3,
        source: str = "stylist",
    ) -> Dict[str, Any]:
        """Run up to ``max_rounds`` critic -> visualizer refinement rounds.

        Each round stores its index in ``data["current_critic_round"]``, asks
        the critic for suggestions and then re-renders. The loop stops early
        when the critic answers ``"No changes needed."`` or when a round
        produces no image; in the latter case the last successful image is kept.

        Args:
            data: Input data dictionary.
            task_name: Name of the task (e.g., "diagram", "plot").
            max_rounds: Maximum number of critic iterations.
            source: Source of the input for round 0 critique
                ("stylist" or "planner").

        Returns:
            The data dictionary with critic suggestions added and
            ``eval_image_field`` set to the key of the best available image.
        """
        # Image to fall back on if no critic round ever yields a valid one.
        if source == "planner":
            current_best_image_key = f"target_{task_name}_desc0_base64_jpg"
        else:  # default to stylist
            current_best_image_key = f"target_{task_name}_stylist_desc0_base64_jpg"

        for round_idx in range(max_rounds):
            data["current_critic_round"] = round_idx
            data = await self.critic_agent.process(data, source=source)

            critic_suggestions = data.get(
                f"target_{task_name}_critic_suggestions{round_idx}", ""
            )
            # The suggestion may be missing, None or structured; only a string
            # can be the stop sentinel, anything else means "keep refining".
            if (
                isinstance(critic_suggestions, str)
                and critic_suggestions.strip() == "No changes needed."
            ):
                print(f"[Critic Round {round_idx}] No changes needed. Stopping iteration.")
                break

            data = await self.visualizer_agent.process(data)

            # A missing or empty image means this round's rendering failed.
            new_image_key = f"target_{task_name}_critic_desc{round_idx}_base64_jpg"
            if data.get(new_image_key):
                current_best_image_key = new_image_key
                print(f"[Critic Round {round_idx}] Completed iteration. Visualization SUCCESS.")
            else:
                print(f"[Critic Round {round_idx}] Visualization FAILED (No valid image). Rolling back to previous best: {current_best_image_key}")
                break

        data["eval_image_field"] = current_best_image_key
        return data

    async def _retrieve_if_needed(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the retriever unless ``data`` already carries ``top10_references``."""
        # Results may have been pre-populated by process_queries_batch.
        if "top10_references" in data:
            return data
        return await self.retriever_agent.process(
            data, retrieval_setting=self.exp_config.retrieval_setting
        )

    @staticmethod
    def _resolve_max_critic_rounds(data: Dict[str, Any], default: int) -> int:
        """Return ``data["max_critic_rounds"]``, or ``default`` if absent or None."""
        rounds = data.get("max_critic_rounds")
        return default if rounds is None else rounds

    async def process_single_query(
        self, data: Dict[str, Any], do_eval: bool = True
    ) -> Dict[str, Any]:
        """Run the pipeline selected by ``exp_config.exp_mode`` on one query.

        Supported modes: ``vanilla``, ``dev_planner``, ``dev_planner_stylist``,
        ``dev_planner_critic`` / ``demo_planner_critic``, ``dev_full`` /
        ``demo_full``, ``dev_polish`` and ``dev_retriever``. Modes that render
        an image record its key in ``data["eval_image_field"]``.

        Args:
            data: Input data dictionary for a single candidate.
            do_eval: Whether to run ``evaluation_function`` on the result.
                Forced to False for the ``demo_*`` and ``dev_retriever`` modes.

        Returns:
            The processed data dictionary, with evaluation results added when
            evaluation ran.

        Raises:
            ValueError: If ``exp_mode`` is not one of the supported modes.
        """
        exp_mode = self.exp_config.exp_mode
        task_name = self.exp_config.task_name.lower()

        if exp_mode == "vanilla":
            data = await self.vanilla_agent.process(data)
            data["eval_image_field"] = f"vanilla_{task_name}_base64_jpg"

        elif exp_mode == "dev_planner":
            data = await self._retrieve_if_needed(data)
            data = await self.planner_agent.process(data)
            data = await self.visualizer_agent.process(data)
            data["eval_image_field"] = f"target_{task_name}_desc0_base64_jpg"

        elif exp_mode == "dev_planner_stylist":
            data = await self._retrieve_if_needed(data)
            data = await self.planner_agent.process(data)
            data = await self.stylist_agent.process(data)
            data = await self.visualizer_agent.process(data)
            data["eval_image_field"] = f"target_{task_name}_stylist_desc0_base64_jpg"

        elif exp_mode in ("dev_planner_critic", "demo_planner_critic"):
            data = await self._retrieve_if_needed(data)
            data = await self.planner_agent.process(data)
            data = await self.visualizer_agent.process(data)
            # NOTE(review): this mode defaults to 3 rounds while the full modes
            # default to exp_config.max_critic_rounds - confirm this is intended.
            max_rounds = self._resolve_max_critic_rounds(data, 3)
            data = await self._run_critic_iterations(
                data, task_name, max_rounds=max_rounds, source="planner"
            )
            if "demo" in exp_mode:
                do_eval = False

        elif exp_mode in ("dev_full", "demo_full"):
            data = await self._retrieve_if_needed(data)
            data = await self.planner_agent.process(data)
            data = await self.stylist_agent.process(data)
            data = await self.visualizer_agent.process(data)
            # Per-query override first, experiment configuration otherwise.
            max_rounds = self._resolve_max_critic_rounds(
                data, self.exp_config.max_critic_rounds
            )
            data = await self._run_critic_iterations(
                data, task_name, max_rounds=max_rounds, source="stylist"
            )
            if "demo" in exp_mode:
                do_eval = False

        elif exp_mode == "dev_polish":
            data = await self.polish_agent.process(data)
            data["eval_image_field"] = f"polished_{task_name}_base64_jpg"

        elif exp_mode == "dev_retriever":
            # NOTE(review): unlike the other modes, retrieval_setting is not
            # forwarded here, so the retriever's own default applies - confirm.
            data = await self.retriever_agent.process(data)
            do_eval = False

        else:
            raise ValueError(f"Unknown experiment name: {exp_mode}")

        if do_eval:
            return await self.evaluation_function(data, exp_config=self.exp_config)
        return data

    @classmethod
    def _new_outcome_tallies(cls) -> Dict[str, Dict[str, int]]:
        """Return zeroed human/model/tie counters for every evaluation dimension."""
        return {dim: {"human": 0, "model": 0, "tie": 0} for dim in cls._EVAL_DIMS}

    @classmethod
    def _record_outcomes(
        cls,
        tallies: Dict[str, Dict[str, int]],
        result_data: Dict[str, Any],
        completed: int,
    ) -> Dict[str, str]:
        """Fold one result into ``tallies`` and build the progress-bar postfix.

        Args:
            tallies: Running counters created by ``_new_outcome_tallies``;
                updated in place.
            result_data: The result that has just completed.
            completed: Number of results completed so far, including this one;
                used as the denominator of the displayed rates.

        Returns:
            A mapping from a short dimension label to a
            ``"model/tie/human"`` percentage string, containing only the
            dimensions for which ``result_data`` has a ``<dim>_outcome`` key.
        """
        postfix = {}
        for dim in cls._EVAL_DIMS:
            winner_key = f"{dim}_outcome"
            if winner_key not in result_data:
                continue

            tally = tallies[dim]
            outcome = result_data[winner_key]
            if outcome == "Human":
                tally["human"] += 1
            elif outcome == "Model":
                tally["model"] += 1
            elif outcome in cls._TIE_OUTCOMES:
                tally["tie"] += 1

            m_rate = (tally["model"] / completed) * 100
            t_rate = (tally["tie"] / completed) * 100
            h_rate = (tally["human"] / completed) * 100
            postfix[dim[:5].capitalize()] = f"{m_rate:.0f}/{t_rate:.0f}/{h_rate:.0f}"
        return postfix

    async def process_queries_batch(
        self,
        data_list: List[Dict[str, Any]],
        max_concurrent: int = 50,
        do_eval: bool = True,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Process a batch of queries concurrently, yielding results as they finish.

        Retriever is run once before parallelization to avoid redundant API
        calls; its results are copied onto every item so that
        ``process_single_query`` skips retrieval.
        NOTE(review): this assumes all items in ``data_list`` are candidates
        for the same query - confirm against the callers.

        Args:
            data_list: Data dictionaries to process; they are updated in place
                with the shared retrieval results.
            max_concurrent: Maximum number of queries processed at once.
            do_eval: Forwarded to ``process_single_query``.

        Yields:
            Result dictionaries in completion order (not input order).
        """
        exp_mode = self.exp_config.exp_mode
        retrieval_setting = self.exp_config.retrieval_setting

        if exp_mode not in self._MODES_WITHOUT_SHARED_RETRIEVAL and data_list:
            print("[Retriever] Running retrieval once for all candidates...")
            retrieved = await self.retriever_agent.process(
                data_list[0], retrieval_setting=retrieval_setting
            )
            # Copy onto every item, including the first: the retriever may
            # return a new dict instead of mutating data_list[0] in place, and
            # without the keys that item would trigger a second retrieval.
            for data in data_list:
                if data is retrieved:
                    continue
                for key in self._RETRIEVAL_KEYS:
                    if key in retrieved:
                        data[key] = retrieved[key]
            print(f"[Retriever] Done. Retrieved {len(retrieved.get('top10_references') or [])} references.")

        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(doc):
            async with semaphore:
                return await self.process_single_query(doc, do_eval=do_eval)

        tasks = [asyncio.create_task(process_with_semaphore(data)) for data in data_list]

        # Running counters keep the per-result bookkeeping O(1) instead of
        # rescanning every finished result for every dimension.
        tallies = self._new_outcome_tallies()
        completed = 0
        try:
            with tqdm(total=len(tasks), desc="Processing concurrently", ascii=True) as pbar:
                for future in asyncio.as_completed(tasks):
                    result_data = await future
                    completed += 1
                    pbar.set_postfix(self._record_outcomes(tallies, result_data, completed))
                    pbar.update(1)
                    yield result_data
        finally:
            # Reached on normal exhaustion (no-op), on a task failure, or when
            # the consumer stops early: do not leave workers running.
            for task in tasks:
                if not task.done():
                    task.cancel()

    async def evaluation_function(
        self, data: Dict[str, Any], exp_config: "ExpConfig"
    ) -> Dict[str, Any]:
        """Score ``data`` with the referenced evaluation setting (GT shown first).

        Args:
            data: Processed data dictionary to evaluate.
            exp_config: Supplies ``task_name`` and ``work_dir`` for the scorer.

        Returns:
            The data dictionary returned by ``get_score_for_image_referenced``.
        """
        data = await get_score_for_image_referenced(
            data, task_name=exp_config.task_name, work_dir=exp_config.work_dir
        )
        return data