Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Literal | |
| from .llm_client import ( | |
| DEFAULT_GEMINI_DM_MODEL, | |
| DEFAULT_GEMINI_HERO_MODEL, | |
| DEFAULT_HF_DM_MODEL, | |
| DEFAULT_HF_HERO_MODEL, | |
| GeminiStructuredClient, | |
| HuggingFaceStructuredClient, | |
| PROVIDER_GEMINI, | |
| PROVIDER_HF_LOCAL, | |
| StructuredModelClient, | |
| ) | |
# Type aliases naming the closed sets of string values this module's configs use.
StructuredProvider = Literal["gemini", "hf_local"]  # backends for structured (per-role) clients
InterfaceProvider = Literal["strict", "simple", "gemini"]  # interface adapter kinds
InterfaceTranslationMode = Literal["none", "corporate_app"]  # non-"none" requires the "gemini" interface
RoleName = Literal["dm", "hero"]  # roles that each get their own structured-client config

# Fallbacks used when neither an explicit argument nor an environment variable is set.
# (resolve_interface_config picks "gemini" instead of the default provider when a
# translation mode other than "none" is active.)
DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict"
DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite"
DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none"
@dataclass
class StructuredClientConfig:
    """Resolved settings for building one role's structured model client.

    Produced by ``resolve_structured_client_config`` and consumed by
    ``create_structured_client``. The ``@dataclass`` decorator is required:
    without it the class has no generated ``__init__`` and the keyword
    construction done by the resolver raises ``TypeError``.
    """

    # Which role ("dm" or "hero") these settings were resolved for.
    role: RoleName
    # Backend that create_structured_client instantiates.
    provider: StructuredProvider
    # Resolved model identifier.
    # NOTE(review): create_structured_client does not forward this to either
    # client; presumably callers pass it per request — confirm.
    model_name: str
    # Optional adapter path, forwarded only to the HF local client.
    adapter_path: str | None = None
    # Cache directory (resolved from HF_HOME), forwarded only to the HF local client.
    cache_dir: str | None = None
    # Forwarded to the HF local client as load_in_4bit.
    load_in_4bit: bool = True
    # Forwarded to the HF local client as trust_remote_code.
    trust_remote_code: bool = False
@dataclass
class InterfaceConfig:
    """Resolved settings for the interface adapter.

    Produced by ``resolve_interface_config`` and consumed by
    ``build_interface_adapter``. The ``@dataclass`` decorator is required:
    without it the class has no generated ``__init__`` and the keyword
    construction done by the resolver raises ``TypeError``.
    """

    # Which adapter to build: "strict", "simple", or "gemini".
    provider: InterfaceProvider
    # Passed as ``model`` to the Gemini adapter; unused by the other adapters.
    model_name: str = DEFAULT_INTERFACE_MODEL
    # Forwarded to the Gemini adapter only.
    narrate_observations: bool = False
    # Forwarded to the Gemini adapter; resolve_interface_config rejects any
    # mode other than "none" combined with a non-Gemini provider.
    translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE
def resolve_structured_client_config(
    role: RoleName,
    *,
    provider: StructuredProvider | None = None,
    model_name: str | None = None,
    adapter_path: str | None = None,
) -> StructuredClientConfig:
    """Resolve the structured-client settings for ``role``.

    Each of provider, model and adapter path is taken from the explicit
    argument when truthy, otherwise from the role-scoped environment variable
    ``DND_<ROLE>_PROVIDER`` / ``DND_<ROLE>_MODEL`` / ``DND_<ROLE>_ADAPTER_PATH``.
    Remaining fallbacks: Gemini for the provider, and a provider- and
    role-specific default model name. Cache dir comes from ``HF_HOME``; the
    4-bit and remote-code flags come from ``DND_LOAD_IN_4BIT`` and
    ``DND_TRUST_REMOTE_CODE``.

    Raises:
        ValueError: if ``DND_<ROLE>_PROVIDER`` names an unknown provider, or if
            either boolean environment variable is not a recognised boolean.
    """
    prefix = "DND_" + role.upper()

    chosen_provider = provider
    if not chosen_provider:
        chosen_provider = _structured_provider_from_env(os.getenv(prefix + "_PROVIDER"))
    if not chosen_provider:
        chosen_provider = PROVIDER_GEMINI

    # Any provider other than hf_local falls back to the Gemini default models.
    dm_model, hero_model = (
        (DEFAULT_HF_DM_MODEL, DEFAULT_HF_HERO_MODEL)
        if chosen_provider == PROVIDER_HF_LOCAL
        else (DEFAULT_GEMINI_DM_MODEL, DEFAULT_GEMINI_HERO_MODEL)
    )
    fallback_model = dm_model if role == "dm" else hero_model

    # Environment lookups happen in this fixed order so that, when several are
    # malformed, the same error surfaces first.
    resolved_model = model_name or os.getenv(prefix + "_MODEL") or fallback_model
    resolved_adapter = adapter_path or os.getenv(prefix + "_ADAPTER_PATH")
    hf_cache = os.getenv("HF_HOME")
    quantize_4bit = _env_bool("DND_LOAD_IN_4BIT", default=True)
    allow_remote_code = _env_bool("DND_TRUST_REMOTE_CODE", default=False)

    return StructuredClientConfig(
        role=role,
        provider=chosen_provider,
        model_name=resolved_model,
        adapter_path=resolved_adapter,
        cache_dir=hf_cache,
        load_in_4bit=quantize_4bit,
        trust_remote_code=allow_remote_code,
    )
def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient:
    """Instantiate the structured model client selected by ``config.provider``.

    The Gemini client is built with no arguments. The HF local client receives
    the adapter path, cache dir, 4-bit flag and trust_remote_code flag from
    ``config``.
    NOTE(review): ``config.model_name`` and ``config.role`` are not forwarded
    here; presumably the model is chosen elsewhere (e.g. per call) — confirm.

    Raises:
        ValueError: if ``config.provider`` is neither supported provider.
    """
    selected = config.provider
    if selected == PROVIDER_HF_LOCAL:
        return HuggingFaceStructuredClient(
            adapter_path=config.adapter_path,
            cache_dir=config.cache_dir,
            load_in_4bit=config.load_in_4bit,
            trust_remote_code=config.trust_remote_code,
        )
    if selected != PROVIDER_GEMINI:
        raise ValueError(f"Unsupported structured provider: {config.provider}")
    return GeminiStructuredClient()
def resolve_interface_config(
    *,
    provider: InterfaceProvider | None = None,
    model_name: str | None = None,
    narrate_observations: bool | None = None,
    translation_mode: InterfaceTranslationMode | None = None,
) -> InterfaceConfig:
    """Resolve interface-adapter settings from arguments, env vars and defaults.

    Precedence for each field is: explicit argument, then the matching
    ``DND_INTERFACE_*`` environment variable, then a default. When no provider
    is given anywhere, "gemini" is chosen if a translation mode other than
    "none" is active, otherwise ``DEFAULT_INTERFACE_PROVIDER``.

    Raises:
        ValueError: if an environment variable holds an unknown value, or if a
            translation mode other than "none" is combined with a non-Gemini
            provider.
    """
    translation = translation_mode
    if not translation:
        translation = _interface_translation_mode_from_env(os.getenv("DND_INTERFACE_TRANSLATION_MODE"))
    translation = translation or DEFAULT_INTERFACE_TRANSLATION_MODE

    chosen = provider if provider else _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER"))
    if chosen is None:
        # Translation is only supported by the Gemini adapter, so pick it implicitly.
        chosen = DEFAULT_INTERFACE_PROVIDER if translation == "none" else "gemini"

    # Unlike the string fields, an explicit False must win, hence the None test.
    narrate = (
        _env_bool("DND_INTERFACE_NARRATE", default=False)
        if narrate_observations is None
        else narrate_observations
    )

    if not (translation == "none" or chosen == "gemini"):
        raise ValueError("Interface translation mode requires the Gemini interface provider.")

    resolved_model = model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL
    return InterfaceConfig(
        provider=chosen,
        model_name=resolved_model,
        narrate_observations=narrate,
        translation_mode=translation,
    )
def build_interface_adapter(config: InterfaceConfig):
    """Construct the interface adapter named by ``config.provider``.

    "strict" and "simple" adapters take no arguments. The "gemini" adapter
    receives the model name, narration flag and translation mode from
    ``config``.

    Raises:
        ValueError: if ``config.provider`` is not one of the three known kinds.
    """
    # Imported inside the function rather than at module top; presumably to
    # avoid a circular import or defer loading the adapters — confirm.
    from agents.master.interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter, StrictCliInterfaceAdapter

    kind = config.provider
    if kind == "gemini":
        return GeminiInterfaceAdapter(
            model=config.model_name,
            narrate_observations=config.narrate_observations,
            translation_mode=config.translation_mode,
        )
    if kind == "simple":
        return SimpleInterfaceAdapter()
    if kind == "strict":
        return StrictCliInterfaceAdapter()
    raise ValueError(f"Unsupported interface provider: {config.provider}")
def _structured_provider_from_env(value: str | None) -> StructuredProvider | None:
    """Parse a structured-provider name read from an environment variable.

    Args:
        value: Raw environment value, or ``None`` when the variable is unset.

    Returns:
        The stripped, lower-cased provider name when it is ``PROVIDER_GEMINI``
        or ``PROVIDER_HF_LOCAL``. Returns ``None`` when the variable is unset or
        blank, so the caller falls back to its default. Blank is treated like
        unset for consistency with the ``DND_<ROLE>_MODEL`` /
        ``DND_<ROLE>_ADAPTER_PATH`` lookups, which ignore empty strings via ``or``.

    Raises:
        ValueError: if the value is non-blank but names no known provider.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # e.g. `DND_DM_PROVIDER=` exported empty: behave as if unset.
        return None
    if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}:
        raise ValueError(f"Unsupported structured provider value: {value}")
    return normalized  # type: ignore[return-value]
def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None:
    """Parse an interface-provider name read from an environment variable.

    Args:
        value: Raw environment value, or ``None`` when the variable is unset.

    Returns:
        "strict", "simple" or "gemini" (after stripping and lower-casing).
        Returns ``None`` when the variable is unset or blank, so the caller
        applies its own default. Blank is treated like unset, matching how
        ``DND_INTERFACE_MODEL`` ignores empty strings.

    Raises:
        ValueError: if the value is non-blank but not a known provider.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # An exported-but-empty variable should not abort startup.
        return None
    if normalized not in {"strict", "simple", "gemini"}:
        raise ValueError(f"Unsupported interface provider value: {value}")
    return normalized  # type: ignore[return-value]
def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None:
    """Parse an interface translation mode read from an environment variable.

    Args:
        value: Raw environment value, or ``None`` when the variable is unset.

    Returns:
        "none" or "corporate_app" (after stripping and lower-casing). Returns
        ``None`` when the variable is unset or blank, so the caller applies
        its own default. Blank is treated like unset, matching how
        ``DND_INTERFACE_MODEL`` ignores empty strings.

    Raises:
        ValueError: if the value is non-blank but not a known mode.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # An exported-but-empty variable should not abort startup.
        return None
    if normalized not in {"none", "corporate_app"}:
        raise ValueError(f"Unsupported interface translation mode value: {value}")
    return normalized  # type: ignore[return-value]
def _env_bool(name: str, *, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean.

    Accepted spellings (case-insensitive, surrounding whitespace ignored) are
    "1"/"true"/"yes"/"on" for True and "0"/"false"/"no"/"off" for False.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank. Blank is
            treated like unset, consistent with the string settings in this
            module that ignore empty values.

    Raises:
        ValueError: if the variable is set to a non-blank, unrecognised value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    normalized = raw.strip().lower()
    if not normalized:
        # `FLAG=` exported empty: fall back to the default instead of failing.
        return default
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")