# FATHOM-DM: agents/shared/runtime.py
# Deployed by aarushgupta: "Deploy FATHOM-DM Space bundle" (commit 2803d7e, verified)
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Literal
from .llm_client import (
DEFAULT_GEMINI_DM_MODEL,
DEFAULT_GEMINI_HERO_MODEL,
DEFAULT_HF_DM_MODEL,
DEFAULT_HF_HERO_MODEL,
GeminiStructuredClient,
HuggingFaceStructuredClient,
PROVIDER_GEMINI,
PROVIDER_HF_LOCAL,
StructuredModelClient,
)
# --- Closed value sets -------------------------------------------------------
# Backend serving a role's structured model client (see create_structured_client).
StructuredProvider = Literal[
    "gemini",
    "hf_local",
]
# Which interface adapter build_interface_adapter constructs.
InterfaceProvider = Literal[
    "strict",
    "simple",
    "gemini",
]
# Translation applied by the interface; anything but "none" requires Gemini.
InterfaceTranslationMode = Literal[
    "none",
    "corporate_app",
]
# Agent roles that each get their own structured-client configuration.
RoleName = Literal[
    "dm",
    "hero",
]

# --- Interface defaults (used when neither arguments nor env vars decide) ----
DEFAULT_INTERFACE_MODEL = "gemini-2.5-flash-lite"
DEFAULT_INTERFACE_PROVIDER: InterfaceProvider = "strict"
DEFAULT_INTERFACE_TRANSLATION_MODE: InterfaceTranslationMode = "none"
@dataclass(frozen=True)
class StructuredClientConfig:
    """Immutable settings for building one role's structured model client.

    Instances are produced by ``resolve_structured_client_config`` and turned
    into a client by ``create_structured_client`` in this module.

    NOTE(review): ``create_structured_client`` does not pass ``model_name`` to
    either client constructor; confirm where the model name is consumed.
    """

    # Agent this client serves: "dm" or "hero".
    role: RoleName
    # Backend: "gemini" or "hf_local".
    provider: StructuredProvider
    # Model identifier (explicit arg, DND_<ROLE>_MODEL, or a per-role/provider default).
    model_name: str
    # Adapter path forwarded to the HF-local client only; None when not configured.
    adapter_path: str | None = None
    # Cache directory forwarded to the HF-local client (resolved from HF_HOME).
    cache_dir: str | None = None
    # Forwarded to the HF-local client; resolved from DND_LOAD_IN_4BIT (default True).
    load_in_4bit: bool = True
    # Forwarded to the HF-local client; resolved from DND_TRUST_REMOTE_CODE (default False).
    trust_remote_code: bool = False
@dataclass(frozen=True)
class InterfaceConfig:
    """Immutable settings for the interface adapter.

    Produced by ``resolve_interface_config`` and consumed by
    ``build_interface_adapter``.  ``model_name``, ``narrate_observations`` and
    ``translation_mode`` are only forwarded when ``provider == "gemini"``.
    """

    # Adapter kind: "strict", "simple" or "gemini".
    provider: InterfaceProvider
    # Model passed to the Gemini interface adapter.
    model_name: str = DEFAULT_INTERFACE_MODEL
    # Passed through to the Gemini interface adapter; off by default.
    narrate_observations: bool = False
    # "none" or "corporate_app"; non-"none" values require the Gemini provider.
    translation_mode: InterfaceTranslationMode = DEFAULT_INTERFACE_TRANSLATION_MODE
def resolve_structured_client_config(
    role: RoleName,
    *,
    provider: StructuredProvider | None = None,
    model_name: str | None = None,
    adapter_path: str | None = None,
) -> StructuredClientConfig:
    """Build the structured-client settings for ``role``.

    Each setting is taken from the explicit (truthy) argument first, then from
    the environment, then from a built-in default.  With ``ROLE = role.upper()``
    the environment variables consulted are ``DND_<ROLE>_PROVIDER``,
    ``DND_<ROLE>_MODEL``, ``DND_<ROLE>_ADAPTER_PATH``, ``HF_HOME``,
    ``DND_LOAD_IN_4BIT`` (default True) and ``DND_TRUST_REMOTE_CODE``
    (default False).

    The provider falls back to ``PROVIDER_GEMINI``.  The default model depends
    on the provider and on whether ``role == "dm"``; any other role gets the
    hero default.

    Raises:
        ValueError: if the provider variable or a boolean variable holds an
            unsupported value.
    """
    prefix = f"DND_{role.upper()}"

    chosen_provider = provider
    if not chosen_provider:
        from_env = _structured_provider_from_env(os.getenv(f"{prefix}_PROVIDER"))
        chosen_provider = from_env or PROVIDER_GEMINI

    if chosen_provider == PROVIDER_HF_LOCAL:
        dm_default, hero_default = DEFAULT_HF_DM_MODEL, DEFAULT_HF_HERO_MODEL
    else:
        dm_default, hero_default = DEFAULT_GEMINI_DM_MODEL, DEFAULT_GEMINI_HERO_MODEL
    fallback_model = dm_default if role == "dm" else hero_default

    # Resolved in the same order as the config fields, so env errors surface
    # in a stable order (4-bit flag before the trust-remote-code flag).
    resolved_model = model_name or os.getenv(f"{prefix}_MODEL") or fallback_model
    resolved_adapter = adapter_path or os.getenv(f"{prefix}_ADAPTER_PATH")
    hf_home = os.getenv("HF_HOME")
    use_4bit = _env_bool("DND_LOAD_IN_4BIT", default=True)
    allow_remote_code = _env_bool("DND_TRUST_REMOTE_CODE", default=False)

    return StructuredClientConfig(
        role=role,
        provider=chosen_provider,
        model_name=resolved_model,
        adapter_path=resolved_adapter,
        cache_dir=hf_home,
        load_in_4bit=use_4bit,
        trust_remote_code=allow_remote_code,
    )
def create_structured_client(config: StructuredClientConfig) -> StructuredModelClient:
    """Instantiate the structured model client selected by ``config.provider``.

    The Gemini client is constructed with no arguments.  The HF-local client
    receives ``adapter_path``, ``cache_dir``, ``load_in_4bit`` and
    ``trust_remote_code`` from ``config``.

    NOTE(review): ``config.model_name`` is not passed to either constructor
    here; confirm that the clients receive the model name some other way.

    Raises:
        ValueError: if ``config.provider`` is not a supported provider.
    """
    backend = config.provider

    if backend == PROVIDER_GEMINI:
        return GeminiStructuredClient()

    if backend == PROVIDER_HF_LOCAL:
        hf_options = dict(
            adapter_path=config.adapter_path,
            cache_dir=config.cache_dir,
            load_in_4bit=config.load_in_4bit,
            trust_remote_code=config.trust_remote_code,
        )
        return HuggingFaceStructuredClient(**hf_options)

    raise ValueError(f"Unsupported structured provider: {backend}")
def resolve_interface_config(
    *,
    provider: InterfaceProvider | None = None,
    model_name: str | None = None,
    narrate_observations: bool | None = None,
    translation_mode: InterfaceTranslationMode | None = None,
) -> InterfaceConfig:
    """Resolve interface settings from arguments, then environment, then defaults.

    Environment variables consulted: ``DND_INTERFACE_TRANSLATION_MODE``,
    ``DND_INTERFACE_PROVIDER``, ``DND_INTERFACE_NARRATE`` (default False) and
    ``DND_INTERFACE_MODEL``.  ``narrate_observations`` is only looked up in the
    environment when the argument is ``None``, so an explicit ``False`` wins.

    If no provider is given by argument or environment, a translation mode
    other than "none" selects "gemini"; otherwise
    ``DEFAULT_INTERFACE_PROVIDER`` is used.

    Raises:
        ValueError: if an environment value is unsupported, or if a
            translation mode other than "none" is paired with a non-Gemini
            provider.
    """
    translation = translation_mode
    if not translation:
        env_translation = _interface_translation_mode_from_env(
            os.getenv("DND_INTERFACE_TRANSLATION_MODE")
        )
        translation = env_translation or DEFAULT_INTERFACE_TRANSLATION_MODE
    translating = translation != "none"

    chosen_provider = provider or _interface_provider_from_env(os.getenv("DND_INTERFACE_PROVIDER"))
    if chosen_provider is None:
        # Translation is only implemented by the Gemini adapter, so infer it.
        chosen_provider = "gemini" if translating else DEFAULT_INTERFACE_PROVIDER

    if narrate_observations is None:
        narrate = _env_bool("DND_INTERFACE_NARRATE", default=False)
    else:
        narrate = narrate_observations

    if translating and chosen_provider != "gemini":
        raise ValueError("Interface translation mode requires the Gemini interface provider.")

    chosen_model = model_name or os.getenv("DND_INTERFACE_MODEL") or DEFAULT_INTERFACE_MODEL
    return InterfaceConfig(
        provider=chosen_provider,
        model_name=chosen_model,
        narrate_observations=narrate,
        translation_mode=translation,
    )
def build_interface_adapter(config: InterfaceConfig):
    """Construct the interface adapter selected by ``config.provider``.

    "strict" and "simple" adapters take no arguments.  The "gemini" adapter
    receives ``model``, ``narrate_observations`` and ``translation_mode`` from
    ``config``.

    Raises:
        ValueError: if ``config.provider`` is not a supported interface provider.
    """
    # Imported lazily, at call time rather than module import time (presumably
    # to avoid an import cycle or a heavy import; verify).
    from agents.master.interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter, StrictCliInterfaceAdapter

    kind = config.provider
    if kind == "gemini":
        return GeminiInterfaceAdapter(
            model=config.model_name,
            narrate_observations=config.narrate_observations,
            translation_mode=config.translation_mode,
        )
    if kind == "simple":
        return SimpleInterfaceAdapter()
    if kind == "strict":
        return StrictCliInterfaceAdapter()
    raise ValueError(f"Unsupported interface provider: {kind}")
def _structured_provider_from_env(value: str | None) -> StructuredProvider | None:
    """Parse a raw structured-provider environment value (e.g. ``DND_DM_PROVIDER``).

    Args:
        value: The raw environment string, or ``None`` when the variable is unset.

    Returns:
        The provider name, stripped and lower-cased.  Returns ``None`` when the
        variable is unset or blank, so the caller falls back to its default.

    Raises:
        ValueError: if a non-blank value is not a supported provider.
    """
    if value is None:
        return None
    normalized = value.strip().lower()
    if not normalized:
        # Blank (``DND_DM_PROVIDER=``) means "unset".  This matches how the
        # ``*_MODEL`` variables fall back to their defaults when empty.
        return None
    if normalized not in {PROVIDER_GEMINI, PROVIDER_HF_LOCAL}:
        raise ValueError(f"Unsupported structured provider value: {value}")
    return normalized  # type: ignore[return-value]
def _interface_provider_from_env(value: str | None) -> InterfaceProvider | None:
if value is None:
return None
normalized = value.strip().lower()
if normalized not in {"strict", "simple", "gemini"}:
raise ValueError(f"Unsupported interface provider value: {value}")
return normalized # type: ignore[return-value]
def _interface_translation_mode_from_env(value: str | None) -> InterfaceTranslationMode | None:
if value is None:
return None
normalized = value.strip().lower()
if normalized not in {"none", "corporate_app"}:
raise ValueError(f"Unsupported interface translation mode value: {value}")
return normalized # type: ignore[return-value]
def _env_bool(name: str, *, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Matching ignores case and surrounding whitespace.  "1"/"true"/"yes"/"on"
    map to True and "0"/"false"/"no"/"off" map to False.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Raises:
        ValueError: if the variable holds any other non-blank value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    normalized = raw.strip().lower()
    if not normalized:
        # ``NAME=`` (blank) is treated as unset instead of raising.  This is
        # consistent with the other env parsers in this module.
        return default
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    raise ValueError(f"Environment variable {name} must be a boolean value, got {raw!r}")