# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for experiments."""

import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal


@dataclass
class ExpConfig:
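    """Experiment configuration.

    Model names are resolved in order: explicit argument, the ``defaults``
    section of ``configs/model_config.yaml`` under ``work_dir``, the
    ``MAIN_MODEL_NAME`` / ``IMAGE_GEN_MODEL_NAME`` environment variables,
    then hard-coded defaults. ``exp_name`` and ``result_dir`` are derived
    in ``__post_init__``.
    """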

    dataset_name: Literal["PaperBananaBench"]
    task_name: Literal["diagram", "plot"] = "diagram"
    split_name: str = "test"
    temperature: float = 1.0
    exp_mode: str = ""
    retrieval_setting: Literal["auto", "manual", "random", "none"] = "auto"
    max_critic_rounds: int = 3
    main_model_name: str = ""
    image_gen_model_name: str = ""
    work_dir: Path = Path(__file__).parent.parent

    timestamp: str | None = None

    def __post_init__(self):
        # Process-wide side effect: sets the timezone used for timestamps (change as needed)
        os.environ["TZ"] = "America/Los_Angeles"
        if hasattr(time, "tzset"):
            time.tzset()  # Unix only; skipped on platforms without tzset (e.g. Windows)

        # Fall back to the YAML config if either model name is missing
        if not self.main_model_name or not self.image_gen_model_name:
            import yaml  # imported lazily; only needed for this fallback
            config_path = self.work_dir / "configs" / "model_config.yaml"
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    model_config_data = yaml.safe_load(f) or {}
                # `or {}` also covers a `defaults:` key that is present but empty
                defaults = model_config_data.get("defaults") or {}
                if not self.main_model_name:
                    self.main_model_name = defaults.get("main_model_name", "")
                if not self.image_gen_model_name:
                    self.image_gen_model_name = defaults.get("image_gen_model_name", "")
        # Fall back to environment variables
        if not self.main_model_name:
            self.main_model_name = os.environ.get("MAIN_MODEL_NAME", "")
        if not self.image_gen_model_name:
            self.image_gen_model_name = os.environ.get("IMAGE_GEN_MODEL_NAME", "")
        # Hard defaults so model name is never empty
        if not self.main_model_name:
            self.main_model_name = "gemini-3.1-pro-preview"
            print(f"Warning: main_model_name not configured, falling back to '{self.main_model_name}'. "
                  "Set it in configs/model_config.yaml, via the MAIN_MODEL_NAME environment variable, "
                  "or via --main-model-name.")
        if not self.image_gen_model_name:
            self.image_gen_model_name = "gemini-3.1-flash-image-preview"
            print(f"Warning: image_gen_model_name not configured, falling back to '{self.image_gen_model_name}'. "
                  "Set it in configs/model_config.yaml, via the IMAGE_GEN_MODEL_NAME environment variable, "
                  "or via --image-gen-model-name.")
        # Default the timestamp to the current time (MMDD_HHMM) unless one was supplied
        if self.timestamp is None:
            self.timestamp = time.strftime("%m%d_%H%M")
        self.exp_name = f"{self.timestamp}_{self.retrieval_setting}ret_{self.exp_mode}_{self.split_name}"

        # Create result_dir (and any missing parents) if it does not exist
        self.result_dir = self.work_dir / "results" / f"{self.dataset_name}_{self.task_name}"
        self.result_dir.mkdir(exist_ok=True, parents=True)