File size: 6,235 Bytes
2803d7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Literal

from .llm_client import (
    DEFAULT_GEMINI_DM_MODEL,
    DEFAULT_GEMINI_HERO_MODEL,
    DEFAULT_HF_DM_MODEL,
    DEFAULT_HF_HERO_MODEL,
    GeminiStructuredClient,
    HuggingFaceStructuredClient,
    PROVIDER_GEMINI,
    PROVIDER_HF_LOCAL,
    StructuredModelClient,
)

# Provider keys accepted for structured model clients; create_structured_client
# dispatches on them. NOTE(review): these literals are expected to equal
# PROVIDER_GEMINI / PROVIDER_HF_LOCAL imported from llm_client — confirm.
StructuredProvider = Literal["gemini", "hf_local"]
# Interface adapter kinds understood by build_interface_adapter.
InterfaceProvider = Literal["strict", "simple", "gemini"]
# Translation modes forwarded to the Gemini interface adapter; any mode other
# than "none" requires the "gemini" interface provider.
InterfaceTranslationMode = Literal["none", "corporate_app"]
# Agent roles; each role reads its own DND_<ROLE>_* environment variables.
RoleName = Literal["dm", "hero"]

# Fallbacks used by resolve_interface_config when neither an explicit argument
# nor the matching DND_INTERFACE_* environment variable supplies a value.
# DEFAULT_INTERFACE_PROVIDER only applies when the translation mode is "none";
# otherwise the provider falls back to "gemini".
DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict"
DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite"
DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none"


@dataclass(frozen=True)
class StructuredClientConfig:
    """Frozen settings describing how to build one role's structured model client.

    Produced by ``resolve_structured_client_config`` and consumed by
    ``create_structured_client``. The Hugging Face specific fields
    (``adapter_path``, ``cache_dir``, ``load_in_4bit``, ``trust_remote_code``)
    are only forwarded to a client when ``provider`` is the hf_local provider.
    """

    # Agent role this client serves ("dm" or "hero").
    role: RoleName
    # Backend key that create_structured_client dispatches on.
    provider: StructuredProvider
    # Model identifier. NOTE(review): create_structured_client does not pass
    # this to either client constructor — confirm where it is consumed.
    model_name: str
    # Adapter path forwarded to HuggingFaceStructuredClient; None means no adapter.
    adapter_path: str | None = None
    # Cache directory forwarded to HuggingFaceStructuredClient;
    # resolve_structured_client_config fills it from HF_HOME.
    cache_dir: str | None = None
    # Forwarded to HuggingFaceStructuredClient; presumably enables 4-bit
    # quantized loading — TODO confirm against the client.
    load_in_4bit: bool = True
    # Forwarded to HuggingFaceStructuredClient; disabled by default.
    trust_remote_code: bool = False


@dataclass(frozen=True)
class InterfaceConfig:
    """Frozen settings for the adapter built by ``build_interface_adapter``.

    Only ``provider`` is consulted for the "strict" and "simple" adapters; the
    remaining fields are forwarded to ``GeminiInterfaceAdapter`` when
    ``provider == "gemini"``.
    """

    # Adapter kind: "strict", "simple" or "gemini".
    provider: InterfaceProvider
    # Passed as ``model=`` to the Gemini adapter.
    model_name: str = DEFAULT_INTERFACE_MODEL
    # Forwarded to the Gemini adapter; presumably toggles narration of
    # observations — TODO confirm against GeminiInterfaceAdapter.
    narrate_observations: bool = False
    # Forwarded to the Gemini adapter. resolve_interface_config rejects modes
    # other than "none" unless provider is "gemini"; this dataclass itself does
    # not enforce that.
    translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE


def resolve_structured_client_config(
    role: RoleName,
    *,
    provider: StructuredProvider | None = None,
    model_name: str | None = None,
    adapter_path: str | None = None,
) -> StructuredClientConfig:
    """Build the structured-client configuration for ``role``.

    Each role-scoped setting is taken from the explicit keyword argument when
    truthy, otherwise from the role's environment (``DND_DM_*`` or
    ``DND_HERO_*``), otherwise from a built-in default:

    * provider -- ``DND_<ROLE>_PROVIDER``, defaulting to ``PROVIDER_GEMINI``.
    * model_name -- ``DND_<ROLE>_MODEL``, defaulting to the role's default
      model for the resolved provider (HF defaults for hf_local, Gemini
      defaults otherwise).
    * adapter_path -- ``DND_<ROLE>_ADAPTER_PATH``, else ``None``.

    Role-independent settings come from ``HF_HOME`` (cache dir),
    ``DND_LOAD_IN_4BIT`` (default True) and ``DND_TRUST_REMOTE_CODE``
    (default False). Blank adapter-path / cache-dir values become ``None``.

    Raises:
        ValueError: if ``role`` is not ``"dm"`` or ``"hero"``, or if one of the
            environment variables above holds an unsupported value.
    """
    # Reject unknown roles up front: otherwise any role other than "dm" would
    # silently receive the hero default models.
    if role not in ("dm", "hero"):
        raise ValueError(f"Unsupported role: {role!r}")
    env_prefix = f"DND_{role.upper()}"
    resolved_provider = (
        provider
        or _structured_provider_from_env(os.getenv(f"{env_prefix}_PROVIDER"))
        or PROVIDER_GEMINI
    )
    is_dm = role == "dm"
    if resolved_provider == PROVIDER_HF_LOCAL:
        default_model = DEFAULT_HF_DM_MODEL if is_dm else DEFAULT_HF_HERO_MODEL
    else:
        default_model = DEFAULT_GEMINI_DM_MODEL if is_dm else DEFAULT_GEMINI_HERO_MODEL
    return StructuredClientConfig(
        role=role,
        provider=resolved_provider,
        model_name=model_name or os.getenv(f"{env_prefix}_MODEL") or default_model,
        # The trailing ``or None`` turns a blank env value into "unset" rather
        # than forwarding an empty path to the HF client.
        adapter_path=adapter_path or os.getenv(f"{env_prefix}_ADAPTER_PATH") or None,
        cache_dir=os.getenv("HF_HOME") or None,
        load_in_4bit=_env_bool("DND_LOAD_IN_4BIT", default=True),
        trust_remote_code=_env_bool("DND_TRUST_REMOTE_CODE", default=False),
    )


def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient:
    """Instantiate the structured model client selected by ``config.provider``.

    The Hugging Face client receives the adapter path, cache directory and
    loading flags from ``config``; the Gemini client is constructed without
    arguments. ``config.role`` and ``config.model_name`` are not passed to
    either constructor here.

    Raises:
        ValueError: if ``config.provider`` matches neither PROVIDER_HF_LOCAL
            nor PROVIDER_GEMINI.
    """
    backend = config.provider
    if backend == PROVIDER_HF_LOCAL:
        return HuggingFaceStructuredClient(
            adapter_path=config.adapter_path,
            cache_dir=config.cache_dir,
            load_in_4bit=config.load_in_4bit,
            trust_remote_code=config.trust_remote_code,
        )
    if backend == PROVIDER_GEMINI:
        return GeminiStructuredClient()
    raise ValueError(f"Unsupported structured provider: {config.provider}")


def resolve_interface_config(
    *,
    provider: InterfaceProvider | None = None,
    model_name: str | None = None,
    narrate_observations: bool | None = None,
    translation_mode: InterfaceTranslationMode | None = None,
) -> InterfaceConfig:
    """Resolve interface adapter settings from arguments, environment and defaults.

    For each field the precedence is: explicit keyword argument, then the
    matching ``DND_INTERFACE_*`` environment variable, then the module default.
    Two fields have special fallbacks:

    * provider -- when neither the argument nor ``DND_INTERFACE_PROVIDER`` is
      set, it becomes "gemini" if a translation mode other than "none" is
      active, and ``DEFAULT_INTERFACE_PROVIDER`` otherwise.
    * narrate_observations -- an explicit ``False`` is honored; only ``None``
      falls through to ``DND_INTERFACE_NARRATE`` (default False).

    Raises:
        ValueError: if an environment variable holds an unsupported value, or
            if a translation mode other than "none" is combined with a
            non-Gemini provider.
    """
    # Translation mode is resolved first because it drives the provider fallback.
    translation = translation_mode
    if not translation:
        env_translation = _interface_translation_mode_from_env(
            os.getenv("DND_INTERFACE_TRANSLATION_MODE")
        )
        translation = env_translation or DEFAULT_INTERFACE_TRANSLATION_MODE

    adapter_kind = provider or _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER"))
    if adapter_kind is None:
        # Only the Gemini adapter accepts a translation mode, so pick it automatically.
        adapter_kind = DEFAULT_INTERFACE_PROVIDER if translation == "none" else "gemini"

    narrate = (
        _env_bool("DND_INTERFACE_NARRATE", default=False)
        if narrate_observations is None
        else narrate_observations
    )

    if not (translation == "none" or adapter_kind == "gemini"):
        raise ValueError("Interface translation mode requires the Gemini interface provider.")

    return InterfaceConfig(
        provider=adapter_kind,
        model_name=model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL,
        narrate_observations=narrate,
        translation_mode=translation,
    )


def build_interface_adapter(config: InterfaceConfig):
    """Construct the interface adapter named by ``config.provider``.

    "strict" and "simple" adapters take no arguments; the "gemini" adapter
    receives the model name, narration flag and translation mode from
    ``config``. The adapter classes are imported lazily, on each call.

    Raises:
        ValueError: if ``config.provider`` is not "strict", "simple" or "gemini".
    """
    from agents.master.interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter, StrictCliInterfaceAdapter

    kind = config.provider
    if kind == "gemini":
        return GeminiInterfaceAdapter(
            model=config.model_name,
            narrate_observations=config.narrate_observations,
            translation_mode=config.translation_mode,
        )
    if kind == "simple":
        return SimpleInterfaceAdapter()
    if kind == "strict":
        return StrictCliInterfaceAdapter()
    raise ValueError(f"Unsupported interface provider: {config.provider}")


def _structured_provider_from_env(value: str | None) -> StructuredProvider | None:
    """Parse a structured-provider environment value.

    Returns ``None`` when the variable is unset or blank, so callers fall back
    to their default. Otherwise returns the stripped, lower-cased provider key.

    Raises:
        ValueError: if the value is neither PROVIDER_GEMINI nor PROVIDER_HF_LOCAL.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # A blank assignment (e.g. ``DND_DM_PROVIDER=``) means "not configured".
        return None
    if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}:
        raise ValueError(f"Unsupported structured provider value: {value}")
    return normalized  # type: ignore[return-value]


def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None:
    """Parse an interface-provider environment value.

    Returns ``None`` when the variable is unset or blank, so callers fall back
    to their default. Otherwise returns the stripped, lower-cased provider name.

    Raises:
        ValueError: if the value is not "strict", "simple" or "gemini".
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # A blank assignment (e.g. ``DND_INTERFACE_PROVIDER=``) means "not configured".
        return None
    if normalized not in {"strict", "simple", "gemini"}:
        raise ValueError(f"Unsupported interface provider value: {value}")
    return normalized  # type: ignore[return-value]


def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None:
    """Parse an interface translation-mode environment value.

    Returns ``None`` when the variable is unset or blank, so callers fall back
    to their default. Otherwise returns the stripped, lower-cased mode.

    Raises:
        ValueError: if the value is not "none" or "corporate_app".
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # A blank assignment (e.g. ``DND_INTERFACE_TRANSLATION_MODE=``) means "not configured".
        return None
    if normalized not in {"none", "corporate_app"}:
        raise ValueError(f"Unsupported interface translation mode value: {value}")
    return normalized  # type: ignore[return-value]


def _env_bool(name: str, *, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Accepts, case-insensitively and ignoring surrounding whitespace,
    "1"/"true"/"yes"/"on" as True and "0"/"false"/"no"/"off" as False.
    An unset or blank variable yields ``default``.

    Raises:
        ValueError: if the variable holds any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    normalized = raw.strip().lower()
    if not normalized:
        # Treat ``FLAG=`` (empty assignment) like an unset variable.
        return default
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")