Spaces:
Sleeping
Sleeping
| # Copyright 2026 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| CLI entry point for PaperBanana Skill. | |
| Generates publication-quality academic diagrams and plots from method text. | |
| """ | |
| import argparse | |
| import asyncio | |
| import base64 | |
| import shutil | |
| import sys | |
| from io import BytesIO | |
| from pathlib import Path | |
# This script sits one directory below the project root; put that root at the
# front of sys.path so project-level imports resolve when run as a script.
_SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = _SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))
def ensure_model_config():
    """Create ``configs/model_config.yaml`` from its template when absent.

    Nothing happens if the config file already exists, or if the template
    ``configs/model_config.template.yaml`` is itself missing.
    """
    config_root = PROJECT_ROOT / "configs"
    target = config_root / "model_config.yaml"
    if target.exists():
        # Never overwrite an existing (possibly user-edited) config.
        return

    template = config_root / "model_config.template.yaml"
    if template.exists():
        shutil.copy2(template, target)
def ensure_dataset(task_name: str):
    """Fetch the PaperBananaBench data for *task_name* unless already on disk.

    The data is considered present when both ``ref.json`` and the ``images``
    directory exist under ``data/PaperBananaBench/<task_name>``. Otherwise the
    matching files are downloaded from the ``dwzhu/PaperBananaBench`` dataset
    repository via ``huggingface_hub.snapshot_download``.

    Exits the process with status 1 if ``huggingface_hub`` is not installed.
    """
    bench_root = PROJECT_ROOT / "data" / "PaperBananaBench"
    task_dir = bench_root / task_name
    if (task_dir / "ref.json").exists() and (task_dir / "images").exists():
        return

    # Imported lazily: the dependency is only needed when a download happens.
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print(
            "ERROR: huggingface_hub is required for automatic dataset download.\n"
            "Install it with: pip install huggingface_hub",
            file=sys.stderr,
        )
        sys.exit(1)

    print(f"Downloading PaperBananaBench/{task_name} from HuggingFace...")
    snapshot_download(
        "dwzhu/PaperBananaBench",
        repo_type="dataset",
        # Restrict the download to the requested task's files.
        allow_patterns=[f"{task_name}/*"],
        local_dir=str(bench_root),
    )
| def extract_final_image_b64(result: dict, exp_mode: str) -> str | None: | |
| """Return the base64-encoded final image from a pipeline result dict. | |
| Follows the same fallback order as demo.py:display_candidate_result. | |
| """ | |
| task_name = "diagram" | |
| # Try critic rounds 3 → 0 | |
| for round_idx in range(3, -1, -1): | |
| key = f"target_{task_name}_critic_desc{round_idx}_base64_jpg" | |
| if key in result and result[key]: | |
| return result[key] | |
| # Fallback: stylist (demo_full) or planner | |
| if exp_mode == "demo_full": | |
| key = f"target_{task_name}_stylist_desc0_base64_jpg" | |
| else: | |
| key = f"target_{task_name}_desc0_base64_jpg" | |
| return result.get(key) | |
async def run(args):
    """Run the pipeline for parsed CLI arguments and save the resulting images.

    Steps: validate the input text, make sure the model config and dataset
    exist, build the agents and processor, process ``args.num_candidates``
    identical queries, then decode each candidate's final image and write it
    to disk. The path of every saved image is printed to stdout, one per line.

    Args:
        args: Parsed arguments providing ``content``, ``content_file``,
            ``caption``, ``task``, ``output``, ``aspect_ratio``,
            ``max_critic_rounds``, ``num_candidates``, ``retrieval_setting``,
            ``main_model_name``, ``image_gen_model_name`` and ``exp_mode``.

    Exits the process with status 1 when no content is supplied, the content
    file cannot be read, the pipeline yields no results, or no candidate
    produced an image.
    """
    # Resolve and validate the input text first, so an invalid invocation
    # fails before any dataset download or agent construction happens.
    content = args.content
    if args.content_file:
        try:
            content = Path(args.content_file).read_text(encoding="utf-8")
        except OSError as exc:
            print(
                f"ERROR: cannot read --content-file {args.content_file!r}: {exc}",
                file=sys.stderr,
            )
            sys.exit(1)
    if not content:
        print("ERROR: --content or --content-file is required.", file=sys.stderr)
        sys.exit(1)

    ensure_model_config()
    # NOTE(review): args.task is only used here; it is not forwarded to
    # ExpConfig or to extract_final_image_b64 -- confirm whether "plot" runs
    # need it passed along.
    ensure_dataset(args.task)

    # Late imports so env is ready
    from agents.planner_agent import PlannerAgent
    from agents.visualizer_agent import VisualizerAgent
    from agents.stylist_agent import StylistAgent
    from agents.critic_agent import CriticAgent
    from agents.retriever_agent import RetrieverAgent
    from agents.vanilla_agent import VanillaAgent
    from agents.polish_agent import PolishAgent
    from utils import config
    from utils.paperviz_processor import PaperVizProcessor

    exp_mode = args.exp_mode
    exp_config = config.ExpConfig(
        dataset_name="Demo",
        split_name="demo",
        exp_mode=exp_mode,
        retrieval_setting=args.retrieval_setting,
        main_model_name=args.main_model_name,
        image_gen_model_name=args.image_gen_model_name,
        work_dir=PROJECT_ROOT,
    )
    processor = PaperVizProcessor(
        exp_config=exp_config,
        vanilla_agent=VanillaAgent(exp_config=exp_config),
        planner_agent=PlannerAgent(exp_config=exp_config),
        visualizer_agent=VisualizerAgent(exp_config=exp_config),
        stylist_agent=StylistAgent(exp_config=exp_config),
        critic_agent=CriticAgent(exp_config=exp_config),
        retriever_agent=RetrieverAgent(exp_config=exp_config),
        polish_agent=PolishAgent(exp_config=exp_config),
    )

    num_candidates = args.num_candidates

    # Build data dicts: one identical query per requested candidate.
    data_list = [
        {
            "filename": f"skill_candidate_{i}",
            "caption": args.caption,
            "content": content,
            "visual_intent": args.caption,
            "additional_info": {"rounded_ratio": args.aspect_ratio},
            "max_critic_rounds": args.max_critic_rounds,
        }
        for i in range(num_candidates)
    ]

    # Process (parallel when multiple candidates)
    results = []
    async for result_data in processor.process_queries_batch(
        data_list, max_concurrent=num_candidates, do_eval=False
    ):
        results.append(result_data)

    if not results:
        print("ERROR: Pipeline returned no results.", file=sys.stderr)
        sys.exit(1)

    # Save images
    from PIL import Image

    output_path = Path(args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = output_path.suffix or ".png"

    saved_paths = []
    # NOTE(review): idx is the position in which results were yielded, which
    # may differ from the "skill_candidate_{i}" number -- confirm if it matters.
    for idx, result in enumerate(results):
        b64 = extract_final_image_b64(result, exp_mode)
        if not b64:
            print(f"WARNING: No image produced for candidate {idx}.", file=sys.stderr)
            continue

        # Drop a leading "<header>," part (e.g. a data-URI prefix) if present;
        # everything after the first comma is the payload.
        if "," in b64:
            b64 = b64.partition(",")[2]

        if num_candidates == 1:
            save_path = output_path
        else:
            save_path = output_path.parent / f"{output_path.stem}_{idx}{suffix}"

        # The image is always written as PNG data, whatever the file suffix.
        with Image.open(BytesIO(base64.b64decode(b64))) as img:
            img.save(str(save_path), format="PNG")
        saved_paths.append(str(save_path))

    if not saved_paths:
        # Every candidate failed: report it through the exit status instead of
        # finishing "successfully" with no output.
        print("ERROR: No images were produced by any candidate.", file=sys.stderr)
        sys.exit(1)

    for saved in saved_paths:
        print(saved)
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="PaperBanana Skill: generate academic diagrams/plots from text" | |
| ) | |
| parser.add_argument("--content", type=str, default="", | |
| help="Method section text to visualize") | |
| parser.add_argument("--content-file", type=str, default="", | |
| help="Path to a file containing the method section text") | |
| parser.add_argument("--caption", type=str, required=True, | |
| help="Figure caption / visual intent") | |
| parser.add_argument("--task", type=str, default="diagram", | |
| choices=["diagram", "plot"], | |
| help="Task type: diagram or plot") | |
| parser.add_argument("--output", type=str, default="output.png", | |
| help="Output image path (default: output.png)") | |
| parser.add_argument("--aspect-ratio", type=str, default="21:9", | |
| choices=["21:9", "16:9", "3:2"], | |
| help="Aspect ratio (default: 21:9)") | |
| parser.add_argument("--max-critic-rounds", type=int, default=3, | |
| help="Max critic refinement rounds (default: 3)") | |
| parser.add_argument("--num-candidates", type=int, default=10, | |
| help="Number of parallel candidates to generate (default: 10)") | |
| parser.add_argument("--retrieval-setting", type=str, default="auto", | |
| choices=["auto", "manual", "random", "none"], | |
| help="Retrieval mode: auto (VLM selects refs), manual, random, or none (default: auto)") | |
| parser.add_argument("--main-model-name", type=str, default="", | |
| help="Main model name for VLM agents (default: from config, currently gemini-3.1-pro-preview)") | |
| parser.add_argument("--image-gen-model-name", type=str, default="", | |
| help="Model name for image generation (default: from config, currently gemini-3.1-flash-image-preview)") | |
| parser.add_argument("--exp-mode", type=str, default="demo_full", | |
| choices=["demo_full", "demo_planner_critic"], | |
| help="Pipeline mode: demo_full (Retriever+Planner+Stylist+Visualizer+Critic) or demo_planner_critic (Retriever+Planner+Visualizer+Critic, no Stylist) (default: demo_full)") | |
| args = parser.parse_args() | |
| asyncio.run(run(args)) | |
# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()