from __future__ import annotations import json import os from dataclasses import dataclass from datetime import datetime, timezone from time import perf_counter from typing import Any import httpx from trenches_env.models import ( AgentAction, AgentObservation, EntityModelBinding, ExternalSignal, ProviderAgentDiagnostics, ) from trenches_env.rl import AGENT_ALLOWED_ACTIONS _OPENAI_COMPATIBLE_PROVIDERS = {"openai", "openrouter", "huggingface", "ollama", "vllm", "custom"} _DEFAULT_BASE_URLS = { "openai": "https://api.openai.com/v1", "openrouter": "https://openrouter.ai/api/v1", "huggingface": "https://router.huggingface.co/v1", "ollama": "http://127.0.0.1:11434/v1", "vllm": "http://127.0.0.1:8000/v1", } _RETRYABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504} _HF_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} class ProviderDecisionError(RuntimeError): pass @dataclass(frozen=True) class ProviderDecisionRequest: agent_id: str binding: EntityModelBinding observation: AgentObservation external_signals: list[ExternalSignal] @dataclass class _ProviderRuntimeStats: request_count: int = 0 success_count: int = 0 error_count: int = 0 consecutive_failures: int = 0 total_latency_ms: float = 0.0 last_latency_ms: float | None = None last_success_at: datetime | None = None last_error_at: datetime | None = None last_error: str | None = None class ProviderDecisionRuntime: def __init__( self, client: httpx.Client | None = None, timeout_seconds: float = 20.0, max_attempts: int = 2, ) -> None: self._client = client or httpx.Client(timeout=timeout_seconds) self._owns_client = client is None self._max_attempts = max(1, max_attempts) self._stats: dict[str, _ProviderRuntimeStats] = {} def close(self) -> None: if self._owns_client: self._client.close() def decide_action(self, request: ProviderDecisionRequest) -> AgentAction: binding = request.binding if not binding.ready_for_inference: raise ProviderDecisionError("binding is not ready for inference") stats = self._stats.setdefault(request.agent_id, _ProviderRuntimeStats()) stats.request_count += 1 last_error: str | None = None for attempt in range(1, self._max_attempts + 1): started = perf_counter() try: payload = self._request_payload(request) action = self._payload_to_action(request.agent_id, binding, payload) latency_ms = round((perf_counter() - started) * 1000.0, 2) self._record_success(stats, latency_ms) action.metadata.setdefault("provider_attempts", attempt) return action except ProviderDecisionError as exc: latency_ms = round((perf_counter() - started) * 1000.0, 2) last_error = str(exc) if attempt >= self._max_attempts or not self._is_retryable_error(exc): self._record_error(stats, latency_ms, last_error) raise except httpx.RequestError as exc: latency_ms = round((perf_counter() - started) * 1000.0, 2) last_error = f"provider network error: {exc}" if attempt >= self._max_attempts: self._record_error(stats, latency_ms, last_error) raise ProviderDecisionError(last_error) from exc self._record_error(stats, 0.0, last_error or "provider decision failed") raise ProviderDecisionError(last_error or "provider decision failed") def diagnostics(self, bindings: dict[str, EntityModelBinding]) -> list[ProviderAgentDiagnostics]: diagnostics: list[ProviderAgentDiagnostics] = [] for agent_id, binding in bindings.items(): stats = self._stats.get(agent_id, _ProviderRuntimeStats()) diagnostics.append( ProviderAgentDiagnostics( agent_id=agent_id, provider=binding.provider, model_name=binding.model_name, configured=binding.configured, ready_for_inference=binding.ready_for_inference, decision_mode=binding.decision_mode, status=self._status_for(binding, stats), request_count=stats.request_count, success_count=stats.success_count, error_count=stats.error_count, consecutive_failures=stats.consecutive_failures, last_latency_ms=stats.last_latency_ms, avg_latency_ms=round(stats.total_latency_ms / stats.success_count, 2) if stats.success_count else None, last_success_at=stats.last_success_at, last_error_at=stats.last_error_at, last_error=stats.last_error, ) ) return diagnostics def _request_payload(self, request: ProviderDecisionRequest) -> dict[str, Any]: provider = request.binding.provider if provider in _OPENAI_COMPATIBLE_PROVIDERS: return self._request_openai_compatible(request) if provider == "anthropic": return self._request_anthropic(request) raise ProviderDecisionError(f"unsupported provider: {provider}") def _request_openai_compatible(self, request: ProviderDecisionRequest) -> dict[str, Any]: binding = request.binding base_url = (binding.base_url or _DEFAULT_BASE_URLS.get(binding.provider) or "").rstrip("/") if not base_url: raise ProviderDecisionError("missing base_url for provider binding") url = f"{base_url}/chat/completions" headers = {"Content-Type": "application/json"} api_key = self._resolve_api_key(binding) if api_key: headers["Authorization"] = f"Bearer {api_key}" if binding.provider == "openrouter": headers.setdefault("HTTP-Referer", "https://trenches.local") headers.setdefault("X-Title", "Trenches") body: dict[str, Any] = { "model": self._resolved_model_name(binding), "temperature": 0.1, "messages": self._messages(request), } if binding.supports_tool_calls: body["tools"] = [self._openai_emit_action_tool(request.agent_id)] body["tool_choice"] = { "type": "function", "function": {"name": "emit_action"}, } response = self._client.post(url, headers=headers, json=body) self._raise_for_status(response) payload = response.json() choices = payload.get("choices") or [] if not choices: raise ProviderDecisionError("provider returned no choices") message = choices[0].get("message") or {} tool_calls = message.get("tool_calls") or [] if tool_calls: arguments = tool_calls[0].get("function", {}).get("arguments", "{}") return self._parse_json_payload(arguments) content = message.get("content") if isinstance(content, list): content = "".join( block.get("text", "") for block in content if isinstance(block, dict) ) if not isinstance(content, str) or not content.strip(): raise ProviderDecisionError("provider returned empty message content") return self._parse_json_payload(content) @staticmethod def _resolved_model_name(binding: EntityModelBinding) -> str: if binding.provider != "huggingface": return binding.model_name policy = (os.getenv("TRENCHES_HF_ROUTING_POLICY") or "fastest").strip().lower() if ":" in binding.model_name or policy not in _HF_ROUTING_POLICIES: return binding.model_name return f"{binding.model_name}:{policy}" def _request_anthropic(self, request: ProviderDecisionRequest) -> dict[str, Any]: binding = request.binding base_url = (binding.base_url or "https://api.anthropic.com/v1").rstrip("/") api_key = self._resolve_api_key(binding) if not api_key: raise ProviderDecisionError("anthropic provider requires an API key") body: dict[str, Any] = { "model": binding.model_name, "max_tokens": 350, "temperature": 0.1, "system": self._system_prompt(request.agent_id), "messages": [{"role": "user", "content": self._user_prompt(request)}], } if binding.supports_tool_calls: body["tools"] = [self._anthropic_emit_action_tool(request.agent_id)] response = self._client.post( f"{base_url}/messages", headers={ "Content-Type": "application/json", "x-api-key": api_key, "anthropic-version": "2023-06-01", }, json=body, ) self._raise_for_status(response) payload = response.json() content = payload.get("content") or [] for block in content: if isinstance(block, dict) and block.get("type") == "tool_use": return block.get("input") or {} text_blocks = [ block.get("text", "") for block in content if isinstance(block, dict) and block.get("type") == "text" ] if not text_blocks: raise ProviderDecisionError("anthropic provider returned no usable content") return self._parse_json_payload("\n".join(text_blocks)) @staticmethod def _resolve_api_key(binding: EntityModelBinding) -> str | None: if not binding.api_key_env: return None value = os.getenv(binding.api_key_env) return value.strip() if value else None @staticmethod def _raise_for_status(response: httpx.Response) -> None: try: response.raise_for_status() except httpx.HTTPStatusError as exc: message = f"provider returned HTTP {exc.response.status_code}" raise ProviderDecisionError(message) from exc @staticmethod def _messages(request: ProviderDecisionRequest) -> list[dict[str, Any]]: return [ {"role": "system", "content": ProviderDecisionRuntime._system_prompt(request.agent_id)}, {"role": "user", "content": ProviderDecisionRuntime._user_prompt(request)}, ] @staticmethod def _system_prompt(agent_id: str) -> str: return ( f"You are the decision runtime for {agent_id} in a geopolitical simulation. " "Choose exactly one legal action. Do not invent actions, targets, or tools that are not provided. " "If using text output, return strict JSON with keys type, target, and summary." ) @staticmethod def _user_prompt(request: ProviderDecisionRequest) -> str: observation = request.observation payload = { "decision_prompt": observation.decision_prompt, "available_actions": observation.available_actions, "projection": observation.projection.model_dump(mode="json"), "public_brief": [brief.model_dump(mode="json") for brief in observation.public_brief[:4]], "private_brief": [brief.model_dump(mode="json") for brief in observation.private_brief[:6]], "strategic_state": observation.strategic_state, "asset_alerts": observation.asset_alerts[:6], "available_data_sources": [source.model_dump(mode="json") for source in observation.available_data_sources[:8]], "external_signals": [signal.model_dump(mode="json") for signal in request.external_signals[:6]], } return json.dumps(payload, ensure_ascii=True) @staticmethod def _openai_emit_action_tool(agent_id: str) -> dict[str, Any]: return { "type": "function", "function": { "name": "emit_action", "description": f"Emit exactly one legal action for {agent_id}.", "parameters": ProviderDecisionRuntime._action_schema(agent_id), }, } @staticmethod def _anthropic_emit_action_tool(agent_id: str) -> dict[str, Any]: return { "name": "emit_action", "description": f"Emit exactly one legal action for {agent_id}.", "input_schema": ProviderDecisionRuntime._action_schema(agent_id), } @staticmethod def _action_schema(agent_id: str) -> dict[str, Any]: return { "type": "object", "properties": { "type": { "type": "string", "enum": list(AGENT_ALLOWED_ACTIONS.get(agent_id, ())), }, "target": { "type": ["string", "null"], }, "summary": { "type": "string", "minLength": 8, }, }, "required": ["type", "summary"], "additionalProperties": False, } @staticmethod def _parse_json_payload(raw: str) -> dict[str, Any]: text = raw.strip() if text.startswith("```"): lines = [line for line in text.splitlines() if not line.startswith("```")] text = "\n".join(lines).strip() try: payload = json.loads(text) except json.JSONDecodeError as exc: raise ProviderDecisionError(f"provider returned invalid JSON: {exc}") from exc if not isinstance(payload, dict): raise ProviderDecisionError("provider payload must be a JSON object") return payload @staticmethod def _payload_to_action(agent_id: str, binding: EntityModelBinding, payload: dict[str, Any]) -> AgentAction: action_type = payload.get("type") summary = str(payload.get("summary", "")).strip() target = payload.get("target") if not isinstance(action_type, str) or action_type not in AGENT_ALLOWED_ACTIONS.get(agent_id, ()): raise ProviderDecisionError(f"provider selected illegal action for {agent_id}: {action_type}") if not summary: raise ProviderDecisionError("provider did not return an action summary") if target is not None and not isinstance(target, str): raise ProviderDecisionError("provider target must be a string or null") return AgentAction( actor=agent_id, type=action_type, target=target, summary=summary, metadata={ "mode": "provider_inference", "provider": binding.provider, "model": binding.model_name, }, ) @staticmethod def _is_retryable_error(error: ProviderDecisionError) -> bool: message = str(error).lower() if "timeout" in message or "timed out" in message: return True if "connection" in message or "network" in message: return True if "http " in message: for status_code in _RETRYABLE_STATUS_CODES: if f"http {status_code}" in message: return True return False @staticmethod def _record_success(stats: _ProviderRuntimeStats, latency_ms: float) -> None: stats.success_count += 1 stats.consecutive_failures = 0 stats.last_latency_ms = latency_ms stats.total_latency_ms += latency_ms stats.last_success_at = datetime.now(timezone.utc) stats.last_error = None @staticmethod def _record_error(stats: _ProviderRuntimeStats, latency_ms: float, error: str) -> None: stats.error_count += 1 stats.consecutive_failures += 1 stats.last_latency_ms = latency_ms if latency_ms > 0.0 else stats.last_latency_ms stats.last_error = error stats.last_error_at = datetime.now(timezone.utc) @staticmethod def _status_for(binding: EntityModelBinding, stats: _ProviderRuntimeStats) -> str: if not binding.ready_for_inference: return "fallback_only" if stats.request_count == 0: return "idle" if stats.consecutive_failures > 0: return "degraded" return "healthy"