File size: 6,398 Bytes
9823a7e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | import json
import logging
import os
import warnings
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import ClassVar, Optional, Set
import yaml
from .exceptions import BeakerConfigurationError
# Default on-disk locations, both inside ``~/.beaker``. They remain ``None`` when the
# user's home directory cannot be determined.
DEFAULT_CONFIG_LOCATION: Optional[Path] = None
DEFAULT_INTERNAL_CONFIG_LOCATION: Optional[Path] = None
try:
    _beaker_dir = Path.home() / ".beaker"
except RuntimeError:
    # Can't locate home directory; leave both defaults unset.
    pass
else:
    DEFAULT_CONFIG_LOCATION = _beaker_dir / "config.yml"
    DEFAULT_INTERNAL_CONFIG_LOCATION = _beaker_dir / ".beaker-py.json"
    del _beaker_dir
# Module logger, named after this module so it slots into the standard logging hierarchy.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["Config"]
@dataclass
class Config:
    """
    Configuration for the Beaker client.

    Usually created with :meth:`from_env`, which reads a local config file (if one can be
    found) and then applies environment-variable and keyword-argument overrides.
    """

    user_token: str
    """
    Beaker user token that can be obtained from
    `beaker.org <https://beaker.org/user>`_.
    """

    agent_address: str = "https://beaker.org"
    """
    The address of the Beaker server.
    """

    rpc_address: str = "beaker.org:443"
    """
    The address of the Beaker gRPC server.
    """

    default_org: Optional[str] = "ai2"
    """
    Default Beaker organization to use.
    """

    default_workspace: Optional[str] = None
    """
    Default Beaker workspace to use.
    """

    default_image: Optional[str] = None
    """
    The default image used for interactive sessions.
    """

    # Environment variable names consulted by `from_env`, `find_config` and `save`.
    ADDRESS_KEY: ClassVar[str] = "BEAKER_ADDR"
    CONFIG_PATH_KEY: ClassVar[str] = "BEAKER_CONFIG"
    TOKEN_KEY: ClassVar[str] = "BEAKER_TOKEN"
    # Keys that may appear in a config file but are not fields of this class; `from_path`
    # drops them silently instead of warning about an unknown field.
    IGNORE_FIELDS: ClassVar[Set[str]] = {"updater_timestamp", "updater_message"}

    def __post_init__(self):
        # Treat empty strings as "unset" so callers only ever need to check for `None`.
        # (Previously `default_org` was compared to 0 instead of checking its length, so an
        # empty organization was never normalized.)
        for name in ("default_org", "default_workspace", "default_image"):
            value = getattr(self, name)
            if value is not None and len(value) == 0:
                setattr(self, name, None)

    def __str__(self) -> str:
        """Render the config with the user token masked."""
        fields_str = "user_token=***, " + ", ".join(
            [f"{f.name}={getattr(self, f.name)}" for f in fields(self) if f.name != "user_token"]
        )
        return f"{self.__class__.__name__}({fields_str})"

    @classmethod
    def from_env(cls, **overrides) -> "Config":
        """
        Initialize a config from environment variables or a local config file if one
        can be found.

        Values are resolved in this order, later sources winning:

        1. the config file located by :meth:`find_config` (if any),
        2. the ``BEAKER_TOKEN`` and ``BEAKER_ADDR`` environment variables,
        3. keyword arguments passed as ``overrides``.

        .. note::
            Environment variables take precedence over values in the config file.

        :param overrides: Values for any of the config fields.
        :raises BeakerConfigurationError: If no user token can be found, the resulting
            token is empty, or an override names something that is not a config field.
        """
        config: Config
        path = cls.find_config()
        if path is not None:
            config = cls.from_path(path)
            if cls.TOKEN_KEY in os.environ:
                config.user_token = os.environ[cls.TOKEN_KEY]
        elif cls.TOKEN_KEY in os.environ:
            config = cls(user_token=os.environ[cls.TOKEN_KEY])
        elif "user_token" in overrides:
            config = cls(user_token=overrides["user_token"])
        else:
            raise BeakerConfigurationError(
                f"Failed to find config file or environment variable '{cls.TOKEN_KEY}'"
            )

        # Override with environment variables.
        if cls.ADDRESS_KEY in os.environ:
            config.agent_address = os.environ[cls.ADDRESS_KEY]

        # Override with any arguments passed to this method. Only dataclass fields are
        # accepted: `hasattr` would also match methods and class constants (e.g. `save`,
        # `TOKEN_KEY`), silently shadowing them on the instance.
        field_names = {f.name for f in fields(config)}
        for name, value in overrides.items():
            if name not in field_names:
                raise BeakerConfigurationError(f"Beaker config has no attribute '{name}'")
            setattr(config, name, value)
        # Apply the same empty-string -> `None` normalization to overridden values.
        config.__post_init__()

        if not config.user_token:
            raise BeakerConfigurationError("Invalid Beaker user token, token is empty")
        return config

    @classmethod
    def from_path(cls, path: Path) -> "Config":
        """
        Initialize a config from a local YAML config file.

        Keys listed in ``IGNORE_FIELDS`` are dropped, unknown keys are dropped with a
        :class:`RuntimeWarning`, and empty-string values are replaced with ``None``.

        :raises BeakerConfigurationError: If the file does not contain a YAML mapping
            (for example, if it is empty).
        """
        with open(path) as config_file:
            logger.debug("Loading beaker config from '%s'", path)
            data = yaml.load(config_file, Loader=yaml.SafeLoader)

        # An empty file parses to `None`; any other non-mapping document is equally unusable.
        if not isinstance(data, dict):
            raise BeakerConfigurationError(
                f"Invalid Beaker config file '{path}', expected a mapping of config fields"
            )

        field_names = {f.name for f in fields(cls)}
        kwargs = {}
        for key, value in data.items():
            if key in cls.IGNORE_FIELDS:
                continue
            if key not in field_names:
                warnings.warn(
                    f"Unknown field '{key}' found in config '{path}'. "
                    f"If this is a bug, please report it at https://github.com/allenai/beaker-py/issues/new/",
                    RuntimeWarning,
                )
                continue
            # Replace empty strings with `None`.
            kwargs[key] = None if isinstance(value, str) and value == "" else value
        return cls(**kwargs)

    def save(self, path: Optional[Path] = None):
        """
        Save the config to the given path as YAML (including the user token).

        If ``path`` is not given, the path in the ``BEAKER_CONFIG`` environment variable is
        used, falling back to ``~/.beaker/config.yml``. Parent directories are created as
        needed. A plain ``str`` path is also accepted.

        :raises ValueError: If no path is given and no default location can be determined.
        """
        if path is None:
            if self.CONFIG_PATH_KEY in os.environ:
                path = Path(os.environ[self.CONFIG_PATH_KEY])
            elif DEFAULT_CONFIG_LOCATION is not None:
                path = DEFAULT_CONFIG_LOCATION
        if path is None:
            raise ValueError("param 'path' is required")
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as config_file:
            yaml.dump(asdict(self), config_file)

    @classmethod
    def find_config(cls) -> Optional[Path]:
        """
        Locate a local config file.

        If the ``BEAKER_CONFIG`` environment variable is set, only that path is considered;
        there is no fallback to the default location when it does not point at a file.
        Otherwise ``~/.beaker/config.yml`` is used when it exists.

        :returns: The path to an existing config file, or ``None`` if none was found.
        """
        if cls.CONFIG_PATH_KEY in os.environ:
            path = Path(os.environ[cls.CONFIG_PATH_KEY])
            if path.is_file():
                return path
        elif DEFAULT_CONFIG_LOCATION is not None and DEFAULT_CONFIG_LOCATION.is_file():
            return DEFAULT_CONFIG_LOCATION
        return None
@dataclass
class InternalConfig:
    """
    Small JSON-backed state file used internally by beaker-py, stored by default at
    ``DEFAULT_INTERNAL_CONFIG_LOCATION``.
    """

    # NOTE(review): presumably the time of the last version check — confirm against callers.
    version_checked: Optional[float] = None

    @classmethod
    def load(cls, path: Optional[Path] = None) -> Optional["InternalConfig"]:
        """
        Load the internal config.

        :param path: File to read. Defaults to ``DEFAULT_INTERNAL_CONFIG_LOCATION``.
        :returns: ``None`` if no path is given and no default location could be
            determined (home directory unknown). Otherwise the loaded config, or a fresh
            default config if the file does not exist, cannot be read, is not valid JSON,
            or is not a JSON object.
        """
        if path is None:
            path = DEFAULT_INTERNAL_CONFIG_LOCATION
        if path is None:
            return None
        path = Path(path)
        if not path.is_file():
            return cls()
        try:
            with open(path, "r") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # This file only holds internal state, so a corrupt or unreadable file should
            # not break the client; start over from defaults instead.
            # (`json.JSONDecodeError` and `UnicodeDecodeError` are both `ValueError`s.)
            return cls()
        if not isinstance(data, dict):
            return cls()
        # Ignore keys this class doesn't define (e.g. possibly written by a different
        # beaker-py version) instead of failing with a `TypeError`.
        field_names = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in field_names})

    def save(self, path: Optional[Path] = None):
        """
        Write the internal config as JSON, creating parent directories as needed.

        :param path: File to write. Defaults to ``DEFAULT_INTERNAL_CONFIG_LOCATION``; if
            that is also ``None`` (home directory unknown), nothing is written.
        """
        if path is None:
            path = DEFAULT_INTERNAL_CONFIG_LOCATION
        if path is None:
            return None
        path = Path(path)
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w") as f:
            json.dump(asdict(self), f)
|