workflow-twin / env /runtime_config.py
NDGCodes's picture
fix repo structure for HF
1a692ce
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable runtime settings passed to the environment.

    Instances can be built directly (every field has a default) or from
    ``OPENENV_*`` environment variables via :meth:`from_env`.
    """

    seed: int = 42
    difficulty: str = "easy"
    use_quantizer: bool = False
    # from_env only produces values from _QUANT_MODES; anything else becomes "full".
    quant_mode: str = "full"
    quant_every_n_steps: int = 1
    embedding_dim: int = 16
    quant_bits: int = 3
    distortion_lambda: float = 0.2
    inner_product_lambda: float = 0.1

    # Deliberately unannotated so @dataclass does not treat it as a field.
    _QUANT_MODES = frozenset({"off", "full", "throttle", "status", "hybrid"})

    @staticmethod
    def _to_bool(value: str | None, default: bool = False) -> bool:
        """Interpret an environment-variable string as a boolean.

        Returns ``default`` when ``value`` is None; otherwise True only for
        "1", "true", "yes" or "on" (case-insensitive, surrounding whitespace
        ignored) and False for anything else.
        """
        if value is None:
            return default
        return value.strip().lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _env_raw(name: str) -> str | None:
        """Return the stripped value of env var ``name``, or None if unset or blank."""
        value = os.getenv(name)
        if value is None:
            return None
        value = value.strip()
        # A variable that is defined but empty is treated the same as unset.
        return value or None

    @classmethod
    def _env_int(cls, name: str, default: int, minimum: int | None = None) -> int:
        """Read env var ``name`` as an int, falling back to ``default``.

        If ``minimum`` is given, the result is clamped to be at least that value.

        Raises:
            ValueError: if the variable is set but is not a valid integer.
        """
        raw = cls._env_raw(name)
        if raw is None:
            value = default
        else:
            try:
                value = int(raw)
            except ValueError as err:
                raise ValueError(f"{name} must be an integer, got {raw!r}") from err
        if minimum is not None:
            value = max(minimum, value)
        return value

    @classmethod
    def _env_float(cls, name: str, default: float) -> float:
        """Read env var ``name`` as a float, falling back to ``default``.

        Raises:
            ValueError: if the variable is set but is not a valid number.
        """
        raw = cls._env_raw(name)
        if raw is None:
            return default
        try:
            return float(raw)
        except ValueError as err:
            raise ValueError(f"{name} must be a number, got {raw!r}") from err

    @classmethod
    def from_env(cls) -> "RuntimeConfig":
        """Build a config from environment variables.

        Variables read (unset or whitespace-only values use the field default):
        OPENENV_SEED (falls back to ENV_SEED), OPENENV_DIFFICULTY,
        OPENENV_USE_QUANTIZER, OPENENV_QUANT_MODE, OPENENV_QUANT_EVERY_N_STEPS,
        OPENENV_EMBEDDING_DIM, OPENENV_QUANT_BITS, OPENENV_DISTORTION_LAMBDA,
        OPENENV_INNER_PRODUCT_LAMBDA.

        OPENENV_QUANT_MODE is case-insensitive; an unrecognized value falls
        back to "full". The step interval, embedding dim and quant bits are
        clamped to at least 1.

        Raises:
            ValueError: if a numeric variable cannot be parsed; the message
                names the offending variable.
        """
        quant_mode = (cls._env_raw("OPENENV_QUANT_MODE") or "full").lower()
        if quant_mode not in cls._QUANT_MODES:
            quant_mode = "full"
        # OPENENV_SEED wins when it has a value; a blank one defers to ENV_SEED.
        seed_var = "OPENENV_SEED" if cls._env_raw("OPENENV_SEED") is not None else "ENV_SEED"
        return cls(
            seed=cls._env_int(seed_var, 42),
            difficulty=cls._env_raw("OPENENV_DIFFICULTY") or "easy",
            use_quantizer=cls._to_bool(cls._env_raw("OPENENV_USE_QUANTIZER"), False),
            quant_mode=quant_mode,
            quant_every_n_steps=cls._env_int("OPENENV_QUANT_EVERY_N_STEPS", 1, minimum=1),
            embedding_dim=cls._env_int("OPENENV_EMBEDDING_DIM", 16, minimum=1),
            quant_bits=cls._env_int("OPENENV_QUANT_BITS", 3, minimum=1),
            distortion_lambda=cls._env_float("OPENENV_DISTORTION_LAMBDA", 0.2),
            inner_product_lambda=cls._env_float("OPENENV_INNER_PRODUCT_LAMBDA", 0.1),
        )

    def to_env_kwargs(self) -> dict[str, Any]:
        """Return the settings as a keyword-argument dict (``difficulty`` first)."""
        return {
            "difficulty": self.difficulty,
            "seed": self.seed,
            "use_quantizer": self.use_quantizer,
            "quant_mode": self.quant_mode,
            "quant_every_n_steps": self.quant_every_n_steps,
            "embedding_dim": self.embedding_dim,
            "quant_bits": self.quant_bits,
            "distortion_lambda": self.distortion_lambda,
            "inner_product_lambda": self.inner_product_lambda,
        }

    def as_dict(self) -> dict[str, Any]:
        """Return the same mapping as :meth:`to_env_kwargs`."""
        return self.to_env_kwargs()