PaperBanana / utils /config.py
dwzhu
Initial deployment: Gradio app + PaperBananaBench data
587f33e
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration for experiments
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
@dataclass
class ExpConfig:
"""Experiment configuration"""
dataset_name: Literal["PaperBananaBench"]
task_name: Literal["diagram", "plot"] = "diagram"
split_name: str = "test"
temperature: float = 1.0
exp_mode: str = ""
retrieval_setting: Literal["auto", "manual", "random", "none"] = "auto"
max_critic_rounds: int = 3
main_model_name: str = ""
image_gen_model_name: str = ""
work_dir: Path = Path(__file__).parent.parent
timestamp: str | None = None
def __post_init__(self):
os.environ["TZ"] = "America/Los_Angeles" # set the timezone as you like
if hasattr(time, "tzset"):
time.tzset() # Only available on Unix; no-op guard for Windows
# Fallback to yaml config if no model name provided
if not self.main_model_name or not self.image_gen_model_name:
import yaml
config_path = self.work_dir / "configs" / "model_config.yaml"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
model_config_data = yaml.safe_load(f) or {}
if not self.main_model_name:
self.main_model_name = model_config_data.get("defaults", {}).get("main_model_name", "")
if not self.image_gen_model_name:
self.image_gen_model_name = model_config_data.get("defaults", {}).get("image_gen_model_name", "")
# Fallback to environment variables
if not self.main_model_name:
self.main_model_name = os.environ.get("MAIN_MODEL_NAME", "")
if not self.image_gen_model_name:
self.image_gen_model_name = os.environ.get("IMAGE_GEN_MODEL_NAME", "")
# Hard defaults so model name is never empty
if not self.main_model_name:
self.main_model_name = "gemini-3.1-pro-preview"
print(f"Warning: main_model_name not configured, falling back to '{self.main_model_name}'. "
"Set it in configs/model_config.yaml or via --main-model-name.")
if not self.image_gen_model_name:
self.image_gen_model_name = "gemini-3.1-flash-image-preview"
print(f"Warning: image_gen_model_name not configured, falling back to '{self.image_gen_model_name}'. "
"Set it in configs/model_config.yaml or via --image-gen-model-name.")
self.timestamp = (
time.strftime("%m%d_%H%M") if self.timestamp is None else self.timestamp
)
self.exp_name = f"{self.timestamp}_{self.retrieval_setting}ret_{self.exp_mode}_{self.split_name}"
# mkdir result_dir if not exists
self.result_dir = self.work_dir / "results" / f"{self.dataset_name}_{self.task_name}"
self.result_dir.mkdir(exist_ok=True, parents=True)