from __future__ import annotations import os from dataclasses import dataclass from typing import Literal from .llm_client import ( DEFAULT_GEMINI_DM_MODEL, DEFAULT_GEMINI_HERO_MODEL, DEFAULT_HF_DM_MODEL, DEFAULT_HF_HERO_MODEL, GeminiStructuredClient, HuggingFaceStructuredClient, PROVIDER_GEMINI, PROVIDER_HF_LOCAL, StructuredModelClient, ) StructuredProvider = Literal["gemini", "hf_local"] InterfaceProvider = Literal["strict", "simple", "gemini"] InterfaceTranslationMode = Literal["none", "corporate_app"] RoleName = Literal["dm", "hero"] DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict" DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite" DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none" @dataclass(frozen=True) class StructuredClientConfig: role: RoleName provider: StructuredProvider model_name: str adapter_path: str | None = None cache_dir: str | None = None load_in_4bit: bool = True trust_remote_code: bool = False @dataclass(frozen=True) class InterfaceConfig: provider: InterfaceProvider model_name: str = DEFAULT_INTERFACE_MODEL narrate_observations: bool = False translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE def resolve_structured_client_config( role: RoleName, *, provider: StructuredProvider | None = None, model_name: str | None = None, adapter_path: str | None = None, ) -> StructuredClientConfig: env_prefix = f"DND_{role.upper()}" resolved_provider = provider or _structured_provider_from_env(os.getenv(f"{env_prefix}_PROVIDER")) or PROVIDER_GEMINI if resolved_provider == PROVIDER_HF_LOCAL: default_model = DEFAULT_HF_DM_MODEL if role == "dm" else DEFAULT_HF_HERO_MODEL else: default_model = DEFAULT_GEMINI_DM_MODEL if role == "dm" else DEFAULT_GEMINI_HERO_MODEL return StructuredClientConfig( role=role, provider=resolved_provider, model_name=model_name or os.getenv(f"{env_prefix}_MODEL") or default_model, adapter_path=adapter_path or os.getenv(f"{env_prefix}_ADAPTER_PATH"), cache_dir=os.getenv("HF_HOME"), load_in_4bit=_env_bool("DND_LOAD_IN_4BIT", default=True), trust_remote_code=_env_bool("DND_TRUST_REMOTE_CODE", default=False), ) def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient: if config.provider == PROVIDER_GEMINI: return GeminiStructuredClient() if config.provider == PROVIDER_HF_LOCAL: return HuggingFaceStructuredClient( adapter_path=config.adapter_path, cache_dir=config.cache_dir, load_in_4bit=config.load_in_4bit, trust_remote_code=config.trust_remote_code, ) raise ValueError(f"Unsupported structured provider: {config.provider}") def resolve_interface_config( *, provider: InterfaceProvider | None = None, model_name: str | None = None, narrate_observations: bool | None = None, translation_mode: InterfaceTranslationMode | None = None, ) -> InterfaceConfig: resolved_translation = ( translation_mode or _interface_translation_mode_from_env(os.getenv("DND_INTERFACE_TRANSLATION_MODE")) or DEFAULT_INTERFACE_TRANSLATION_MODE ) resolved_provider = provider or _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER")) if resolved_provider is None: resolved_provider = "gemini" if resolved_translation != "none" else DEFAULT_INTERFACE_PROVIDER resolved_narrate = narrate_observations if resolved_narrate is None: resolved_narrate = _env_bool("DND_INTERFACE_NARRATE", default=False) if resolved_translation != "none" and resolved_provider != "gemini": raise ValueError("Interface translation mode requires the Gemini interface provider.") return InterfaceConfig( provider=resolved_provider, model_name=model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL, narrate_observations=resolved_narrate, translation_mode=resolved_translation, ) def build_interface_adapter(config: InterfaceConfig): from agents.master.interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter, StrictCliInterfaceAdapter if config.provider == "strict": return StrictCliInterfaceAdapter() if config.provider == "simple": return SimpleInterfaceAdapter() if config.provider == "gemini": return GeminiInterfaceAdapter( model=config.model_name, narrate_observations=config.narrate_observations, translation_mode=config.translation_mode, ) raise ValueError(f"Unsupported interface provider: {config.provider}") def _structured_provider_from_env(value: str | None) -> StructuredProvider | None: if value is None: return None normalized = value.strip().lower() if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}: raise ValueError(f"Unsupported structured provider value: {value}") return normalized # type: ignore[return-value] def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None: if value is None: return None normalized = value.strip().lower() if normalized not in {"strict", "simple", "gemini"}: raise ValueError(f"Unsupported interface provider value: {value}") return normalized # type: ignore[return-value] def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None: if value is None: return None normalized = value.strip().lower() if normalized not in {"none", "corporate_app"}: raise ValueError(f"Unsupported interface translation mode value: {value}") return normalized # type: ignore[return-value] def _env_bool(name: str, *, default: bool) -> bool: raw = os.getenv(name) if raw is None: return default normalized = raw.strip().lower() if normalized in {"1", "true", "yes", "on"}: return True if normalized in {"0", "false", "no", "off"}: return False raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")