File size: 3,940 Bytes
aceb1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Agent Proxy Base Module

Provides the abstract base class and data structures for agent proxies,
plus a factory registry for creating proxy instances from configuration.

Agent proxies allow annotators to interact with AI agents live during
annotation tasks. Each proxy type (echo, http, openai) handles
communication with a specific kind of agent backend.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List
import logging
import time

logger = logging.getLogger(__name__)


@dataclass
class AgentMessage:
    """
    One turn in a conversation between an annotator and an agent.

    Only ``role`` and ``content`` are required; ``timestamp`` and
    ``metadata`` are filled in automatically when omitted.
    """

    # Who produced the message: "user", "agent", "system", or "error".
    role: str

    # The message text.
    content: str

    # Defaults to the value of time.time() at the moment the instance is built.
    timestamp: float = field(default_factory=time.time)

    # Arbitrary extra data; every instance gets its own fresh dict by default.
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AgentResponse:
    """
    What an agent proxy hands back after a message has been sent.

    Only ``message`` is required; ``done`` and ``error`` have defaults.
    """

    # The reply, wrapped as an AgentMessage.
    message: AgentMessage

    # False unless set explicitly.
    # NOTE(review): presumably marks the interaction as finished - confirm
    # against the code that consumes responses.
    done: bool = False

    # Optional error text; None (the default) when no error was recorded.
    error: Optional[str] = None


class BaseAgentProxy(ABC):
    """
    Common interface that every agent proxy implements.

    A concrete proxy handles communication with one kind of agent backend
    (echo for testing, HTTP for generic REST APIs, OpenAI for chat
    completions). Subclasses must implement ``_initialize``,
    ``start_session`` and ``send_message``; ``end_session`` is an optional
    hook whose default does nothing.
    """

    # Empty on the base class (subclasses presumably set their own name).
    proxy_type: str = ""

    def __init__(self, config: dict):
        """Keep ``config`` on the instance and run the ``_initialize`` hook."""
        self.config = config
        # Runs after the assignment above so the hook can read self.config.
        self._initialize()

    @abstractmethod
    def _initialize(self):
        """Validate config and set up connections; invoked from __init__."""

    @abstractmethod
    def start_session(self, task_description: str) -> dict:
        """
        Begin a new interaction session.

        Args:
            task_description: The task the annotator should accomplish
                with the agent.

        Returns:
            A proxy-specific session context dict (stored in
            AgentSession.proxy_context).
        """

    @abstractmethod
    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """
        Deliver a message to the agent and block until it replies.

        Args:
            message: Text of the user's message.
            session_context: The proxy-specific context that
                start_session returned.

        Returns:
            An AgentResponse carrying the agent's reply.
        """

    def end_session(self, session_context: dict):
        """
        Release any resources tied to a session.

        The default implementation does nothing; override when cleanup
        is needed.

        Args:
            session_context: The proxy-specific context that
                start_session returned.
        """


class AgentProxyFactory:
    """
    Factory registry for creating agent proxy instances.

    Proxy classes are stored in a class-level dict keyed by type name.
    All methods are classmethods that operate on that shared registry.
    """

    # Registered proxy classes, keyed by proxy type name.
    _proxies: Dict[str, type] = {}

    @classmethod
    def register(cls, proxy_type: str, proxy_class: type):
        """
        Register a proxy class under a type name.

        Registering a name that is already present replaces the earlier
        entry.

        Args:
            proxy_type: Name used to select this proxy in configuration.
            proxy_class: Class to instantiate for that name.
        """
        cls._proxies[proxy_type] = proxy_class
        # %-style arguments defer formatting until DEBUG is actually enabled.
        logger.debug("Registered agent proxy type: %s", proxy_type)

    @classmethod
    def create(cls, config: dict) -> "BaseAgentProxy":
        """
        Create an agent proxy from configuration.

        Args:
            config: The full config dict. Reads from config["agent_proxy"];
                that section (not the full config) is passed to the
                proxy class constructor.

        Returns:
            Configured BaseAgentProxy instance.

        Raises:
            ValueError: If the agent_proxy section is not a mapping, or the
                proxy type is unknown or missing.
        """
        # A section that is present but empty (e.g. None from a config key
        # with no value) is treated like a missing section, so the caller
        # gets the documented ValueError instead of an AttributeError.
        agent_config = config.get("agent_proxy") or {}
        if not hasattr(agent_config, "get"):
            raise ValueError(
                "agent_proxy must be a mapping, got "
                f"{type(agent_config).__name__}"
            )

        proxy_type = agent_config.get("type")
        if not proxy_type:
            raise ValueError("agent_proxy.type is required")

        if proxy_type not in cls._proxies:
            # An empty registry usually means no proxy module registered
            # itself yet; say so rather than printing an empty list.
            supported = ", ".join(sorted(cls._proxies)) or "(none registered)"
            raise ValueError(
                f"Unknown agent proxy type: '{proxy_type}'. "
                f"Supported types: {supported}"
            )

        proxy_class = cls._proxies[proxy_type]
        return proxy_class(agent_config)

    @classmethod
    def get_supported_types(cls) -> List[str]:
        """Get the registered proxy type names, sorted alphabetically."""
        return sorted(cls._proxies)