Spaces:
Runtime error
Runtime error
File size: 6,235 Bytes
2803d7e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Literal
from .llm_client import (
DEFAULT_GEMINI_DM_MODEL,
DEFAULT_GEMINI_HERO_MODEL,
DEFAULT_HF_DM_MODEL,
DEFAULT_HF_HERO_MODEL,
GeminiStructuredClient,
HuggingFaceStructuredClient,
PROVIDER_GEMINI,
PROVIDER_HF_LOCAL,
StructuredModelClient,
)
# Closed vocabularies for the structured (game-logic) model clients.
RoleName = Literal["dm", "hero"]
StructuredProvider = Literal["gemini", "hf_local"]

# Closed vocabularies for the interface layer. A translation mode other than
# "none" is only accepted together with the "gemini" interface provider.
InterfaceProvider = Literal["strict", "simple", "gemini"]
InterfaceTranslationMode = Literal["none", "corporate_app"]

# Fallbacks used when neither an explicit argument nor an environment
# variable supplies a value.
DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict"
DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none"
DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite"
@dataclass(frozen=True)
class StructuredClientConfig:
    """Immutable settings for one role's structured-output model client.

    Built in this module by ``resolve_structured_client_config`` and passed to
    ``create_structured_client``. Fields other than ``role``, ``provider`` and
    ``model_name`` are only forwarded when ``provider`` is the HF-local backend.
    """

    # Which agent this client serves: "dm" or "hero".
    role: RoleName
    # Backend selector: "gemini" or "hf_local".
    provider: StructuredProvider
    # Model identifier. NOTE(review): create_structured_client does not pass
    # this to either client; presumably callers supply it per request — confirm.
    model_name: str
    # Optional adapter path; forwarded only to the HF-local client.
    adapter_path: str | None = None
    # Cache directory (the resolver reads it from HF_HOME); HF-local only.
    cache_dir: str | None = None
    # HF-local loading flag (resolver default: True, via DND_LOAD_IN_4BIT).
    load_in_4bit: bool = True
    # HF-local loading flag (resolver default: False, via DND_TRUST_REMOTE_CODE).
    trust_remote_code: bool = False
@dataclass(frozen=True)
class InterfaceConfig:
    """Immutable settings used by ``build_interface_adapter`` to pick and configure an adapter.

    Only the "gemini" provider consumes ``model_name``, ``narrate_observations``
    and ``translation_mode``. The "strict" and "simple" adapters take no arguments.
    """

    # Adapter selector: "strict", "simple", or "gemini".
    provider: InterfaceProvider
    # Model name handed to the Gemini adapter.
    model_name: str = DEFAULT_INTERFACE_MODEL
    # Narration flag handed to the Gemini adapter.
    narrate_observations: bool = False
    # "none" or "corporate_app". A non-"none" value is only valid with the
    # "gemini" provider; resolve_interface_config enforces that, not this class.
    translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE
def resolve_structured_client_config(
    role: RoleName,
    *,
    provider: StructuredProvider | None = None,
    model_name: str | None = None,
    adapter_path: str | None = None,
) -> StructuredClientConfig:
    """Resolve the structured-client settings for ``role``.

    Each setting comes from the explicit argument when truthy. Otherwise it comes
    from a role-scoped environment variable (``DND_DM_*`` / ``DND_HERO_*``).
    Failing both, a built-in default applies:

    * provider     -- ``DND_<ROLE>_PROVIDER``, default Gemini
    * model_name   -- ``DND_<ROLE>_MODEL``, default depends on provider and role
    * adapter_path -- ``DND_<ROLE>_ADAPTER_PATH``, default None

    HF-local settings always come from the environment: ``HF_HOME`` (cache
    directory), ``DND_LOAD_IN_4BIT`` (default True) and
    ``DND_TRUST_REMOTE_CODE`` (default False). A blank adapter path or
    ``HF_HOME`` is treated as unset (None).

    Args:
        role: "dm" or "hero".
        provider: Optional backend override.
        model_name: Optional model override.
        adapter_path: Optional adapter-path override.

    Returns:
        A frozen ``StructuredClientConfig``.

    Raises:
        ValueError: if ``role`` is not "dm" or "hero", or if an environment
            variable holds an unsupported provider or a non-boolean flag.
    """
    if role not in ("dm", "hero"):
        # Reject unknown roles up front. Otherwise they would read env vars
        # under an unrelated DND_<ROLE> prefix and silently get hero defaults.
        raise ValueError(f"Unsupported role: {role!r}")
    env_prefix = f"DND_{role.upper()}"
    resolved_provider = (
        provider
        or _structured_provider_from_env(os.getenv(f"{env_prefix}_PROVIDER"))
        or PROVIDER_GEMINI
    )
    if resolved_provider == PROVIDER_HF_LOCAL:
        default_model = DEFAULT_HF_DM_MODEL if role == "dm" else DEFAULT_HF_HERO_MODEL
    else:
        default_model = DEFAULT_GEMINI_DM_MODEL if role == "dm" else DEFAULT_GEMINI_HERO_MODEL
    return StructuredClientConfig(
        role=role,
        provider=resolved_provider,
        model_name=model_name or os.getenv(f"{env_prefix}_MODEL") or default_model,
        # The trailing ``or None`` maps a blank env var to "unset". This matches
        # how a blank model variable above falls back to the default.
        adapter_path=adapter_path or os.getenv(f"{env_prefix}_ADAPTER_PATH") or None,
        cache_dir=os.getenv("HF_HOME") or None,
        load_in_4bit=_env_bool("DND_LOAD_IN_4BIT", default=True),
        trust_remote_code=_env_bool("DND_TRUST_REMOTE_CODE", default=False),
    )
def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient:
    """Instantiate the structured-output client named by ``config.provider``.

    The Gemini client is built without arguments. The HF-local client receives
    the adapter path, cache directory and both loading flags from ``config``.
    NOTE(review): ``config.model_name`` is not forwarded to either client here;
    presumably it is supplied per request by callers — confirm.

    Raises:
        ValueError: if ``config.provider`` is neither Gemini nor HF-local.
    """
    backend = config.provider
    if backend == PROVIDER_GEMINI:
        return GeminiStructuredClient()
    if backend == PROVIDER_HF_LOCAL:
        hf_options = dict(
            adapter_path=config.adapter_path,
            cache_dir=config.cache_dir,
            load_in_4bit=config.load_in_4bit,
            trust_remote_code=config.trust_remote_code,
        )
        return HuggingFaceStructuredClient(**hf_options)
    raise ValueError(f"Unsupported structured provider: {config.provider}")
def resolve_interface_config(
    *,
    provider: InterfaceProvider | None = None,
    model_name: str | None = None,
    narrate_observations: bool | None = None,
    translation_mode: InterfaceTranslationMode | None = None,
) -> InterfaceConfig:
    """Resolve which interface adapter to build and how to configure it.

    Each setting comes from the explicit keyword argument, then its environment
    variable, then a built-in default:

    * translation_mode -- ``DND_INTERFACE_TRANSLATION_MODE``, default "none"
    * provider -- ``DND_INTERFACE_PROVIDER``. When neither source gives one,
      the default is "gemini" if a translation mode is active, else "strict".
    * narrate_observations -- ``DND_INTERFACE_NARRATE``, default False
    * model_name -- ``DND_INTERFACE_MODEL``, default ``DEFAULT_INTERFACE_MODEL``

    Raises:
        ValueError: if an environment variable holds an unsupported value, or
            if a translation mode other than "none" is combined with a
            non-Gemini provider.
    """
    # Resolve translation first because it influences the provider default.
    translation = translation_mode or _interface_translation_mode_from_env(
        os.getenv("DND_INTERFACE_TRANSLATION_MODE")
    )
    if not translation:
        translation = DEFAULT_INTERFACE_TRANSLATION_MODE

    chosen_provider = provider
    if not chosen_provider:
        chosen_provider = _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER"))
    if chosen_provider is None:
        chosen_provider = DEFAULT_INTERFACE_PROVIDER if translation == "none" else "gemini"

    narrate = (
        _env_bool("DND_INTERFACE_NARRATE", default=False)
        if narrate_observations is None
        else narrate_observations
    )

    # Only the Gemini adapter is handed a translation mode.
    if translation != "none" and chosen_provider != "gemini":
        raise ValueError("Interface translation mode requires the Gemini interface provider.")

    return InterfaceConfig(
        provider=chosen_provider,
        model_name=model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL,
        narrate_observations=narrate,
        translation_mode=translation,
    )
def build_interface_adapter(config: InterfaceConfig):
    """Construct the interface adapter selected by ``config.provider``.

    The "strict" and "simple" adapters are built without arguments. The
    "gemini" adapter receives the model name, narration flag and translation
    mode. The adapter classes are imported from ``agents.master.interface``
    inside this function. NOTE(review): presumably to avoid an import cycle or
    an eager heavy import — confirm.

    Raises:
        ValueError: if ``config.provider`` is not "strict", "simple" or "gemini".
    """
    from agents.master.interface import (
        GeminiInterfaceAdapter,
        SimpleInterfaceAdapter,
        StrictCliInterfaceAdapter,
    )

    kind = config.provider
    if kind == "gemini":
        return GeminiInterfaceAdapter(
            model=config.model_name,
            narrate_observations=config.narrate_observations,
            translation_mode=config.translation_mode,
        )
    if kind == "simple":
        return SimpleInterfaceAdapter()
    if kind == "strict":
        return StrictCliInterfaceAdapter()
    raise ValueError(f"Unsupported interface provider: {config.provider}")
def _structured_provider_from_env(value: str | None) -> StructuredProvider | None:
    """Parse a structured-provider setting read from the environment.

    Args:
        value: Raw environment value, or None when the variable is unset.

    Returns:
        The lower-cased, stripped provider (``PROVIDER_GEMINI`` or
        ``PROVIDER_HF_LOCAL``), or None when ``value`` is None or blank. A
        blank variable counts as unset, just as a blank model variable falls
        back to its default.

    Raises:
        ValueError: if ``value`` is non-blank but not a supported provider.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}:
        raise ValueError(f"Unsupported structured provider value: {value}")
    return normalized  # type: ignore[return-value]
def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None:
    """Parse an interface-provider setting read from the environment.

    Args:
        value: Raw environment value, or None when the variable is unset.

    Returns:
        The lower-cased, stripped provider ("strict", "simple" or "gemini"), or
        None when ``value`` is None or blank. A blank variable counts as unset,
        just as a blank ``DND_INTERFACE_MODEL`` falls back to the default.

    Raises:
        ValueError: if ``value`` is non-blank but not a known provider.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    if normalized not in {"strict", "simple", "gemini"}:
        raise ValueError(f"Unsupported interface provider value: {value}")
    return normalized  # type: ignore[return-value]
def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None:
    """Parse an interface translation-mode setting read from the environment.

    Args:
        value: Raw environment value, or None when the variable is unset.

    Returns:
        The lower-cased, stripped mode ("none" or "corporate_app"), or None when
        ``value`` is None or blank. A blank variable counts as unset, so the
        caller's default applies.

    Raises:
        ValueError: if ``value`` is non-blank but not a known mode.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        return None
    if normalized not in {"none", "corporate_app"}:
        raise ValueError(f"Unsupported interface translation mode value: {value}")
    return normalized  # type: ignore[return-value]
def _env_bool(name: str, *, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Matching ignores case and surrounding whitespace. ``1``/``true``/``yes``/``on``
    mean True and ``0``/``false``/``no``/``off`` mean False.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed flag, or ``default``.

    Raises:
        ValueError: if the variable holds any other non-blank value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    normalized = raw.strip().lower()
    if not normalized:
        # A blank variable means "unset", not an invalid flag.
        return default
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")
|