# trenches/backend/src/trenches_env/provider_runtime.py
# Codex
# sync main snapshot for HF Space
# 1794757
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from typing import Any
import httpx
from trenches_env.models import (
AgentAction,
AgentObservation,
EntityModelBinding,
ExternalSignal,
ProviderAgentDiagnostics,
)
from trenches_env.rl import AGENT_ALLOWED_ACTIONS
# Providers that are served through the OpenAI-style ``/chat/completions``
# request path (see ``ProviderDecisionRuntime._request_payload``).
_OPENAI_COMPATIBLE_PROVIDERS = {"openai", "openrouter", "huggingface", "ollama", "vllm", "custom"}
# Base URL used when a binding does not set ``base_url`` itself. "custom" has
# no entry here, so a custom binding must carry its own ``base_url``.
_DEFAULT_BASE_URLS = {
    "openai": "https://api.openai.com/v1",
    "openrouter": "https://openrouter.ai/api/v1",
    "huggingface": "https://router.huggingface.co/v1",
    "ollama": "http://127.0.0.1:11434/v1",
    "vllm": "http://127.0.0.1:8000/v1",
}
# HTTP status codes for which a failed provider call is attempted again
# (matched against the error message in ``_is_retryable_error``).
_RETRYABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
# Values accepted from ``TRENCHES_HF_ROUTING_POLICY``; an accepted value is
# appended to Hugging Face model names as a ``:<policy>`` suffix.
_HF_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
class ProviderDecisionError(RuntimeError):
    """Raised when a provider-backed decision cannot be produced.

    Used in this module for bindings that are not ready, unsupported
    providers, HTTP and network failures, and provider replies that cannot
    be turned into a legal action.
    """
@dataclass(frozen=True)
class ProviderDecisionRequest:
    """Immutable bundle of everything needed to ask a provider for one action."""

    # Agent the decision is for; also the key for per-agent runtime stats and
    # for the allowed-action lookup.
    agent_id: str
    # Provider/model configuration that the request is sent with.
    binding: EntityModelBinding
    # Agent observation; selected parts are serialized into the user prompt.
    observation: AgentObservation
    # Extra signals; only the first six are included in the user prompt.
    external_signals: list[ExternalSignal]
@dataclass
class _ProviderRuntimeStats:
    """Mutable per-agent counters kept by ``ProviderDecisionRuntime``."""

    # Number of ``decide_action`` calls that passed the readiness check.
    request_count: int = 0
    # Number of calls that ended with an action being returned.
    success_count: int = 0
    # Number of calls that ended with a recorded error.
    error_count: int = 0
    # Recorded errors since the last success; reset to 0 on success.
    consecutive_failures: int = 0
    # Sum of latencies of successful calls only (used for the average).
    total_latency_ms: float = 0.0
    # Latency of the most recent recorded attempt, if any.
    last_latency_ms: float | None = None
    # UTC timestamp of the most recent success.
    last_success_at: datetime | None = None
    # UTC timestamp of the most recent recorded error.
    last_error_at: datetime | None = None
    # Message of the most recent recorded error; cleared on success.
    last_error: str | None = None
class ProviderDecisionRuntime:
    """Ask a model provider for one agent action and track per-agent health.

    Requests go either through an OpenAI-style ``/chat/completions`` endpoint
    or through the Anthropic ``/messages`` endpoint, depending on the binding's
    provider. Every failure that originates from the provider or its reply is
    surfaced as ``ProviderDecisionError``.

    The runtime can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(
        self,
        client: httpx.Client | None = None,
        timeout_seconds: float = 20.0,
        max_attempts: int = 2,
    ) -> None:
        """Create a runtime.

        Args:
            client: HTTP client to use. When omitted, a client is created with
                ``timeout_seconds`` and closed again by :meth:`close`.
            timeout_seconds: Timeout for the internally created client; ignored
                when ``client`` is supplied.
            max_attempts: Upper bound on attempts per decision; values below 1
                are raised to 1.
        """
        self._owns_client = client is None
        self._client = client if client is not None else httpx.Client(timeout=timeout_seconds)
        self._max_attempts = max(1, max_attempts)
        self._stats: dict[str, _ProviderRuntimeStats] = {}

    def __enter__(self) -> "ProviderDecisionRuntime":
        """Return the runtime itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the runtime when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Close the HTTP client, but only if this runtime created it."""
        if self._owns_client:
            self._client.close()

    def decide_action(self, request: ProviderDecisionRequest) -> AgentAction:
        """Query the provider and return a validated action for the agent.

        Retryable failures are attempted again up to ``max_attempts`` times.
        Only the terminal outcome (success, or the failure that is raised) is
        written to the per-agent stats.

        Raises:
            ProviderDecisionError: if the binding is not ready, the provider
                call fails, or the reply is not a legal action.
        """
        binding = request.binding
        if not binding.ready_for_inference:
            raise ProviderDecisionError("binding is not ready for inference")

        stats = self._stats.setdefault(request.agent_id, _ProviderRuntimeStats())
        stats.request_count += 1
        last_error: str | None = None

        for attempt in range(1, self._max_attempts + 1):
            is_final_attempt = attempt >= self._max_attempts
            started = perf_counter()
            try:
                payload = self._request_payload(request)
                action = self._payload_to_action(request.agent_id, binding, payload)
            except ProviderDecisionError as exc:
                last_error = str(exc)
                if is_final_attempt or not self._is_retryable_error(exc):
                    self._record_error(stats, self._elapsed_ms(started), last_error)
                    raise
            except httpx.RequestError as exc:
                # Transport-level failures are always retried until attempts run out.
                last_error = f"provider network error: {exc}"
                if is_final_attempt:
                    self._record_error(stats, self._elapsed_ms(started), last_error)
                    raise ProviderDecisionError(last_error) from exc
            else:
                self._record_success(stats, self._elapsed_ms(started))
                action.metadata.setdefault("provider_attempts", attempt)
                return action

        # Defensive fallback: the final attempt above always returns or raises.
        self._record_error(stats, 0.0, last_error or "provider decision failed")
        raise ProviderDecisionError(last_error or "provider decision failed")

    def diagnostics(self, bindings: dict[str, EntityModelBinding]) -> list[ProviderAgentDiagnostics]:
        """Build one diagnostics record per binding from the collected stats.

        Agents without recorded stats are reported with zeroed counters. The
        average latency is computed over successful calls only and is ``None``
        when there has been no success.
        """
        diagnostics: list[ProviderAgentDiagnostics] = []
        for agent_id, binding in bindings.items():
            stats = self._stats.get(agent_id, _ProviderRuntimeStats())
            avg_latency_ms = (
                round(stats.total_latency_ms / stats.success_count, 2) if stats.success_count else None
            )
            diagnostics.append(
                ProviderAgentDiagnostics(
                    agent_id=agent_id,
                    provider=binding.provider,
                    model_name=binding.model_name,
                    configured=binding.configured,
                    ready_for_inference=binding.ready_for_inference,
                    decision_mode=binding.decision_mode,
                    status=self._status_for(binding, stats),
                    request_count=stats.request_count,
                    success_count=stats.success_count,
                    error_count=stats.error_count,
                    consecutive_failures=stats.consecutive_failures,
                    last_latency_ms=stats.last_latency_ms,
                    avg_latency_ms=avg_latency_ms,
                    last_success_at=stats.last_success_at,
                    last_error_at=stats.last_error_at,
                    last_error=stats.last_error,
                )
            )
        return diagnostics

    def _request_payload(self, request: ProviderDecisionRequest) -> dict[str, Any]:
        """Dispatch the request to the transport that matches the provider.

        Raises:
            ProviderDecisionError: for a provider that has no transport.
        """
        provider = request.binding.provider
        if provider in _OPENAI_COMPATIBLE_PROVIDERS:
            return self._request_openai_compatible(request)
        if provider == "anthropic":
            return self._request_anthropic(request)
        raise ProviderDecisionError(f"unsupported provider: {provider}")

    def _request_openai_compatible(self, request: ProviderDecisionRequest) -> dict[str, Any]:
        """POST to ``<base_url>/chat/completions`` and extract the action payload.

        A tool call in the first choice takes precedence; otherwise the message
        content (a string, or a list of blocks whose ``text`` is concatenated)
        is parsed as JSON.

        Raises:
            ProviderDecisionError: on a missing base URL, an HTTP error status,
                or a reply that does not contain a usable JSON object.
        """
        binding = request.binding
        base_url = (binding.base_url or _DEFAULT_BASE_URLS.get(binding.provider) or "").rstrip("/")
        if not base_url:
            raise ProviderDecisionError("missing base_url for provider binding")
        url = f"{base_url}/chat/completions"

        headers = {"Content-Type": "application/json"}
        api_key = self._resolve_api_key(binding)
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        if binding.provider == "openrouter":
            headers.setdefault("HTTP-Referer", "https://trenches.local")
            headers.setdefault("X-Title", "Trenches")

        body: dict[str, Any] = {
            "model": self._resolved_model_name(binding),
            "temperature": 0.1,
            "messages": self._messages(request),
        }
        if binding.supports_tool_calls:
            # Force the model to answer through the single emit_action tool.
            body["tools"] = [self._openai_emit_action_tool(request.agent_id)]
            body["tool_choice"] = {
                "type": "function",
                "function": {"name": "emit_action"},
            }

        response = self._client.post(url, headers=headers, json=body)
        self._raise_for_status(response)
        payload = self._decode_response(response)

        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            raise ProviderDecisionError("provider returned no choices")
        first_choice = choices[0]
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        if not isinstance(message, dict):
            message = {}

        tool_calls = message.get("tool_calls")
        if isinstance(tool_calls, list) and tool_calls:
            first_call = tool_calls[0]
            function = first_call.get("function") if isinstance(first_call, dict) else None
            # ``arguments`` is normally a JSON string, but may already be an
            # object; _parse_json_payload accepts both.
            arguments = function.get("arguments", "{}") if isinstance(function, dict) else "{}"
            return self._parse_json_payload(arguments)

        content = message.get("content")
        if isinstance(content, list):
            content = "".join(
                block["text"]
                for block in content
                if isinstance(block, dict) and isinstance(block.get("text"), str)
            )
        if not isinstance(content, str) or not content.strip():
            raise ProviderDecisionError("provider returned empty message content")
        return self._parse_json_payload(content)

    @staticmethod
    def _resolved_model_name(binding: EntityModelBinding) -> str:
        """Return the model name to send, adding a Hugging Face routing suffix.

        For the ``huggingface`` provider the value of
        ``TRENCHES_HF_ROUTING_POLICY`` (default ``fastest``) is appended as
        ``:<policy>`` unless the name already contains ``:`` or the policy is
        not one of the accepted values. Other providers get the name unchanged.
        """
        if binding.provider != "huggingface":
            return binding.model_name
        policy = (os.getenv("TRENCHES_HF_ROUTING_POLICY") or "fastest").strip().lower()
        if ":" in binding.model_name or policy not in _HF_ROUTING_POLICIES:
            return binding.model_name
        return f"{binding.model_name}:{policy}"

    def _request_anthropic(self, request: ProviderDecisionRequest) -> dict[str, Any]:
        """POST to ``<base_url>/messages`` and extract the action payload.

        The first ``tool_use`` block wins; otherwise all ``text`` blocks are
        joined with newlines and parsed as JSON.

        Raises:
            ProviderDecisionError: on a missing API key, an HTTP error status,
                or a reply that does not contain a usable JSON object.
        """
        binding = request.binding
        base_url = (binding.base_url or "https://api.anthropic.com/v1").rstrip("/")
        api_key = self._resolve_api_key(binding)
        if not api_key:
            raise ProviderDecisionError("anthropic provider requires an API key")

        body: dict[str, Any] = {
            "model": binding.model_name,
            "max_tokens": 350,
            "temperature": 0.1,
            "system": self._system_prompt(request.agent_id),
            "messages": [{"role": "user", "content": self._user_prompt(request)}],
        }
        if binding.supports_tool_calls:
            body["tools"] = [self._anthropic_emit_action_tool(request.agent_id)]

        response = self._client.post(
            f"{base_url}/messages",
            headers={
                "Content-Type": "application/json",
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
            },
            json=body,
        )
        self._raise_for_status(response)
        payload = self._decode_response(response)

        content = payload.get("content") or []
        for block in content:
            if isinstance(block, dict) and block.get("type") == "tool_use":
                tool_input = block.get("input") or {}
                if not isinstance(tool_input, dict):
                    raise ProviderDecisionError("provider payload must be a JSON object")
                return tool_input

        text_blocks = [
            str(block.get("text") or "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        ]
        if not text_blocks:
            raise ProviderDecisionError("anthropic provider returned no usable content")
        return self._parse_json_payload("\n".join(text_blocks))

    @staticmethod
    def _resolve_api_key(binding: EntityModelBinding) -> str | None:
        """Read the API key from the environment variable named by the binding.

        Returns ``None`` when no variable is configured, the variable is
        unset, or its value is empty or whitespace only.
        """
        if not binding.api_key_env:
            return None
        value = os.getenv(binding.api_key_env)
        if value is None:
            return None
        return value.strip() or None

    @staticmethod
    def _raise_for_status(response: httpx.Response) -> None:
        """Convert an HTTP error status into ``ProviderDecisionError``.

        The message has the form ``provider returned HTTP <code>``, which is
        what ``_is_retryable_error`` matches on.
        """
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            message = f"provider returned HTTP {exc.response.status_code}"
            raise ProviderDecisionError(message) from exc

    @staticmethod
    def _decode_response(response: httpx.Response) -> dict[str, Any]:
        """Decode the response body as a JSON object.

        Raises:
            ProviderDecisionError: if the body is not valid JSON or the decoded
                value is not an object.
        """
        try:
            payload = response.json()
        except ValueError as exc:  # json.JSONDecodeError is a ValueError
            raise ProviderDecisionError(f"provider returned invalid JSON: {exc}") from exc
        if not isinstance(payload, dict):
            raise ProviderDecisionError("provider payload must be a JSON object")
        return payload

    @staticmethod
    def _messages(request: ProviderDecisionRequest) -> list[dict[str, Any]]:
        """Build the system + user message pair for chat-completions requests."""
        return [
            {"role": "system", "content": ProviderDecisionRuntime._system_prompt(request.agent_id)},
            {"role": "user", "content": ProviderDecisionRuntime._user_prompt(request)},
        ]

    @staticmethod
    def _system_prompt(agent_id: str) -> str:
        """Return the fixed system instruction, naming the agent."""
        return (
            f"You are the decision runtime for {agent_id} in a geopolitical simulation. "
            "Choose exactly one legal action. Do not invent actions, targets, or tools that are not provided. "
            "If using text output, return strict JSON with keys type, target, and summary."
        )

    @staticmethod
    def _user_prompt(request: ProviderDecisionRequest) -> str:
        """Serialize the observation and external signals into a JSON string.

        List-valued sections are truncated (4 public briefs, 6 private briefs,
        6 asset alerts, 8 data sources, 6 external signals).
        """
        observation = request.observation
        payload = {
            "decision_prompt": observation.decision_prompt,
            "available_actions": observation.available_actions,
            "projection": observation.projection.model_dump(mode="json"),
            "public_brief": [brief.model_dump(mode="json") for brief in observation.public_brief[:4]],
            "private_brief": [brief.model_dump(mode="json") for brief in observation.private_brief[:6]],
            "strategic_state": observation.strategic_state,
            "asset_alerts": observation.asset_alerts[:6],
            "available_data_sources": [
                source.model_dump(mode="json") for source in observation.available_data_sources[:8]
            ],
            "external_signals": [signal.model_dump(mode="json") for signal in request.external_signals[:6]],
        }
        return json.dumps(payload, ensure_ascii=True)

    @staticmethod
    def _openai_emit_action_tool(agent_id: str) -> dict[str, Any]:
        """Return the ``emit_action`` tool definition in OpenAI function format."""
        return {
            "type": "function",
            "function": {
                "name": "emit_action",
                "description": f"Emit exactly one legal action for {agent_id}.",
                "parameters": ProviderDecisionRuntime._action_schema(agent_id),
            },
        }

    @staticmethod
    def _anthropic_emit_action_tool(agent_id: str) -> dict[str, Any]:
        """Return the ``emit_action`` tool definition in Anthropic format."""
        return {
            "name": "emit_action",
            "description": f"Emit exactly one legal action for {agent_id}.",
            "input_schema": ProviderDecisionRuntime._action_schema(agent_id),
        }

    @staticmethod
    def _action_schema(agent_id: str) -> dict[str, Any]:
        """Return the JSON schema for an action, with ``type`` limited to the
        actions listed for ``agent_id`` in ``AGENT_ALLOWED_ACTIONS``."""
        return {
            "type": "object",
            "properties": {
                "type": {
                    "type": "string",
                    "enum": list(AGENT_ALLOWED_ACTIONS.get(agent_id, ())),
                },
                "target": {
                    "type": ["string", "null"],
                },
                "summary": {
                    "type": "string",
                    "minLength": 8,
                },
            },
            "required": ["type", "summary"],
            "additionalProperties": False,
        }

    @staticmethod
    def _parse_json_payload(raw: str | dict[str, Any]) -> dict[str, Any]:
        """Turn provider output into a JSON object.

        Args:
            raw: JSON text, optionally wrapped in Markdown code fences, or an
                already-decoded object (returned as is).

        Raises:
            ProviderDecisionError: if the text is not valid JSON, or the value
                is not a JSON object.
        """
        if isinstance(raw, dict):
            return raw
        if not isinstance(raw, str):
            raise ProviderDecisionError("provider payload must be a JSON object")

        text = raw.strip()
        if text.startswith("```"):
            # Drop the fence lines (```json / ```), tolerating indented fences.
            lines = [line for line in text.splitlines() if not line.lstrip().startswith("```")]
            text = "\n".join(lines).strip()
        try:
            payload = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ProviderDecisionError(f"provider returned invalid JSON: {exc}") from exc
        if not isinstance(payload, dict):
            raise ProviderDecisionError("provider payload must be a JSON object")
        return payload

    @staticmethod
    def _payload_to_action(agent_id: str, binding: EntityModelBinding, payload: dict[str, Any]) -> AgentAction:
        """Validate the provider payload and build the ``AgentAction``.

        Raises:
            ProviderDecisionError: if the action type is not allowed for the
                agent, the summary is missing or blank, or the target is
                neither a string nor null.
        """
        action_type = payload.get("type")
        raw_summary = payload.get("summary")
        # A JSON null must count as "no summary", not as the text "None".
        summary = "" if raw_summary is None else str(raw_summary).strip()
        target = payload.get("target")

        if not isinstance(action_type, str) or action_type not in AGENT_ALLOWED_ACTIONS.get(agent_id, ()):
            raise ProviderDecisionError(f"provider selected illegal action for {agent_id}: {action_type}")
        if not summary:
            raise ProviderDecisionError("provider did not return an action summary")
        if target is not None and not isinstance(target, str):
            raise ProviderDecisionError("provider target must be a string or null")

        return AgentAction(
            actor=agent_id,
            type=action_type,
            target=target,
            summary=summary,
            metadata={
                "mode": "provider_inference",
                "provider": binding.provider,
                "model": binding.model_name,
            },
        )

    @staticmethod
    def _is_retryable_error(error: ProviderDecisionError) -> bool:
        """Decide from the error message whether another attempt is worthwhile.

        Timeouts, connection/network wording, and ``HTTP <code>`` with a code
        from ``_RETRYABLE_STATUS_CODES`` are retryable; matching is done on the
        lower-cased message.
        """
        message = str(error).lower()
        if any(marker in message for marker in ("timeout", "timed out", "connection", "network")):
            return True
        if "http " in message:
            return any(f"http {status_code}" in message for status_code in _RETRYABLE_STATUS_CODES)
        return False

    @staticmethod
    def _elapsed_ms(started: float) -> float:
        """Return milliseconds since ``started`` (a ``perf_counter`` value), rounded to 2 places."""
        return round((perf_counter() - started) * 1000.0, 2)

    @staticmethod
    def _record_success(stats: _ProviderRuntimeStats, latency_ms: float) -> None:
        """Update ``stats`` for a successful decision and clear the last error."""
        stats.success_count += 1
        stats.consecutive_failures = 0
        stats.last_latency_ms = latency_ms
        stats.total_latency_ms += latency_ms
        stats.last_success_at = datetime.now(timezone.utc)
        stats.last_error = None

    @staticmethod
    def _record_error(stats: _ProviderRuntimeStats, latency_ms: float, error: str) -> None:
        """Update ``stats`` for a failed decision.

        A non-positive ``latency_ms`` leaves ``last_latency_ms`` untouched.
        """
        stats.error_count += 1
        stats.consecutive_failures += 1
        if latency_ms > 0.0:
            stats.last_latency_ms = latency_ms
        stats.last_error = error
        stats.last_error_at = datetime.now(timezone.utc)

    @staticmethod
    def _status_for(binding: EntityModelBinding, stats: _ProviderRuntimeStats) -> str:
        """Summarize health as ``fallback_only``, ``idle``, ``degraded`` or ``healthy``."""
        if not binding.ready_for_inference:
            return "fallback_only"
        if stats.request_count == 0:
            return "idle"
        if stats.consecutive_failures > 0:
            return "degraded"
        return "healthy"