Spaces:
Sleeping
Sleeping
| # Copyright 2026 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE- | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Configuration for experiments | |
| """ | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Literal | |
| class ExpConfig: | |
| """Experiment configuration""" | |
| dataset_name: Literal["PaperBananaBench"] | |
| task_name: Literal["diagram", "plot"] = "diagram" | |
| split_name: str = "test" | |
| temperature: float = 1.0 | |
| exp_mode: str = "" | |
| retrieval_setting: Literal["auto", "manual", "random", "none"] = "auto" | |
| max_critic_rounds: int = 3 | |
| main_model_name: str = "" | |
| image_gen_model_name: str = "" | |
| work_dir: Path = Path(__file__).parent.parent | |
| timestamp: str | None = None | |
| def __post_init__(self): | |
| os.environ["TZ"] = "America/Los_Angeles" # set the timezone as you like | |
| if hasattr(time, "tzset"): | |
| time.tzset() # Only available on Unix; no-op guard for Windows | |
| # Fallback to yaml config if no model name provided | |
| if not self.main_model_name or not self.image_gen_model_name: | |
| import yaml | |
| config_path = self.work_dir / "configs" / "model_config.yaml" | |
| if config_path.exists(): | |
| with open(config_path, "r", encoding="utf-8") as f: | |
| model_config_data = yaml.safe_load(f) or {} | |
| if not self.main_model_name: | |
| self.main_model_name = model_config_data.get("defaults", {}).get("main_model_name", "") | |
| if not self.image_gen_model_name: | |
| self.image_gen_model_name = model_config_data.get("defaults", {}).get("image_gen_model_name", "") | |
| # Fallback to environment variables | |
| if not self.main_model_name: | |
| self.main_model_name = os.environ.get("MAIN_MODEL_NAME", "") | |
| if not self.image_gen_model_name: | |
| self.image_gen_model_name = os.environ.get("IMAGE_GEN_MODEL_NAME", "") | |
| # Hard defaults so model name is never empty | |
| if not self.main_model_name: | |
| self.main_model_name = "gemini-3.1-pro-preview" | |
| print(f"Warning: main_model_name not configured, falling back to '{self.main_model_name}'. " | |
| "Set it in configs/model_config.yaml or via --main-model-name.") | |
| if not self.image_gen_model_name: | |
| self.image_gen_model_name = "gemini-3.1-flash-image-preview" | |
| print(f"Warning: image_gen_model_name not configured, falling back to '{self.image_gen_model_name}'. " | |
| "Set it in configs/model_config.yaml or via --image-gen-model-name.") | |
| self.timestamp = ( | |
| time.strftime("%m%d_%H%M") if self.timestamp is None else self.timestamp | |
| ) | |
| self.exp_name = f"{self.timestamp}_{self.retrieval_setting}ret_{self.exp_mode}_{self.split_name}" | |
| # mkdir result_dir if not exists | |
| self.result_dir = self.work_dir / "results" / f"{self.dataset_name}_{self.task_name}" | |
| self.result_dir.mkdir(exist_ok=True, parents=True) | |