File size: 5,467 Bytes
399b80c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
provider is to manage sessions of a backend, if the backend is mcp, then provider will manage sessions through servers
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Generic, TypeVar

from .tool import BaseTool
from .types import BackendType, SessionConfig, ToolResult, ToolStatus
from .session import BaseSession
from .security.policies import SecurityPolicyManager
from openspace.config import get_config
from openspace.utils.logging import Logger

# Module-level logger, obtained from the project's Logger helper for this module name.
logger = Logger.get_logger(__name__)
# Type variable for the concrete session class a Provider manages; bound to BaseSession.
TSession = TypeVar('TSession', bound=BaseSession)


class Provider(ABC, Generic[TSession]):
    """Abstract base class for backend providers.

    A provider owns the sessions of one backend. Sessions are kept in
    ``self._sessions`` keyed by session name, and the provider offers helpers
    to list sessions, list their tools and call a tool on a given session.
    Subclasses implement ``initialize``, ``create_session`` and
    ``close_session``.
    """

    def __init__(self, backend_type: BackendType, config: Optional[Dict[str, Any]] = None):
        """Store the backend type and configuration and set up the security policy.

        Args:
            backend_type: Backend this provider serves.
            config: Optional provider configuration; ``None`` is stored as an
                empty dict in ``self.config``.
        """
        self.backend_type = backend_type
        self.config = config or {}
        self.is_initialized = False
        self._sessions: Dict[str, TSession] = {}  # session name -> session object
        # Not used by this base class; presumably available to subclasses for
        # generating session names -- confirm against subclasses.
        self._session_counter: int = 0
        self.security_manager = SecurityPolicyManager()

        self._setup_security_policy(config)

    def _setup_security_policy(self, config: dict | None = None):
        """Load this backend's security policy from the global config and register it.

        Args:
            config: The configuration passed to ``__init__`` (may be ``None``).
                Not used by this implementation.
        """
        security_policy = get_config().get_security_policy(self.backend_type.value)
        # Key the policy by the same backend it was looked up for, rather than
        # by a fixed backend, so lookups by this provider's backend type find it.
        self.security_manager.set_backend_policy(self.backend_type, security_policy)

    async def ensure_initialized(self) -> None:
        """Internal helper: call ``initialize()`` unless ``is_initialized`` is already set."""
        if not self.is_initialized:
            await self.initialize()

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize the provider, calling ``create_session`` for sessions that do not exist yet.

        Subclasses should set ``self.is_initialized = True`` after successful
        initialization.
        """
        pass

    @abstractmethod
    async def create_session(self, session_config: SessionConfig) -> TSession:
        """Create a session and record it in ``self._sessions``."""
        pass

    @abstractmethod
    async def close_session(self, session_name: str) -> None:
        """Close the session registered under ``session_name``."""
        pass

    def list_sessions(self) -> List[str]:
        """Return the names of all sessions currently held by the provider."""
        return list(self._sessions)

    def get_session(self, session_name: str) -> Optional[TSession]:
        """Return the session registered under ``session_name``, or ``None`` if absent."""
        return self._sessions.get(session_name)

    async def close_all_sessions(self) -> None:
        """Close every session, then reset the provider to the uninitialized state.

        An error raised while closing one session is logged and does not stop
        the remaining sessions from being closed.
        """
        # Iterate over a snapshot: close_session implementations may remove
        # entries from _sessions while this loop is running.
        for session_name in list(self._sessions):
            try:
                await self.close_session(session_name)
            except Exception as e:
                # Best-effort shutdown: report through the module logger
                # (consistent with call_tool) and keep going.
                logger.error("Error closing session %s: %s", session_name, e)

        self._sessions.clear()
        self.is_initialized = False

    def __repr__(self) -> str:
        return (f"Provider(backend={self.backend_type.value}, "
                f"initialized={self.is_initialized}, "
                f"sessions={len(self._sessions)}, "
                f"config_items={len(self.config)})")

    async def list_tools(self, session_name: Optional[str] = None) -> List[BaseTool]:
        """Return the tools exposed by one session or by all sessions.

        Args:
            session_name: If given (and non-empty), only that session's tools
                are returned; an unknown name yields an empty list. Otherwise
                the tools of all sessions are concatenated.

        Returns:
            A list of ``BaseTool`` objects.
        """
        await self.ensure_initialized()

        if session_name:
            session = self._sessions.get(session_name)
            return await session.list_tools() if session else []

        tools: list[BaseTool] = []
        for sess in self._sessions.values():
            tools.extend(await sess.list_tools())
        return tools

    async def call_tool(
        self,
        session_name: str,
        tool_name: str,
        parameters: Dict[str, Any] | None = None,
    ) -> ToolResult:
        """Call ``tool_name`` on the session registered under ``session_name``.

        Args:
            session_name: Name of the session that should execute the tool.
            tool_name: Name of the tool to call.
            parameters: Tool arguments; ``None`` is treated as an empty dict.

        Returns:
            The session's ``ToolResult``. If the session does not exist or the
            call raises, a ``ToolResult`` with ``ToolStatus.ERROR`` is returned
            instead of propagating the exception.
        """
        await self.ensure_initialized()
        parameters = parameters or {}

        session = self._sessions.get(session_name)
        if session is None:
            return ToolResult(
                status=ToolStatus.ERROR,
                content="",
                error=f"Session '{session_name}' not found",
                metadata={"session_name": session_name, "tool_name": tool_name},
            )

        try:
            return await session.call_tool(tool_name, parameters)
        except Exception as e:
            # Convert any failure into an error result so callers always get a ToolResult.
            logger.error("Execute tool error: %s @%s - %s", tool_name, session_name, e)
            return ToolResult(
                status=ToolStatus.ERROR,
                content="",
                error=str(e),
                metadata={"session_name": session_name, "tool_name": tool_name},
            )


class ProviderRegistry:
    """
    Registry mapping each BackendType to its Provider, with support for
    registering and looking up providers at runtime.
    """
    def __init__(self) -> None:
        # backend type -> registered provider
        self._providers: dict[BackendType, Provider] = {}

    def register(self, provider: "Provider") -> None:
        """Store ``provider`` under its ``backend_type``, replacing any earlier entry."""
        backend = provider.backend_type
        self._providers[backend] = provider
        logger.debug("Provider for %s registered", backend)

    def get(self, backend: BackendType) -> "Provider":
        """Return the provider registered for ``backend``.

        Raises:
            KeyError: if no provider has been registered for ``backend``.
        """
        registered = self._providers
        if backend in registered:
            return registered[backend]
        raise KeyError(f"Provider for '{backend.value}' not registered")

    def list(self) -> dict[BackendType, "Provider"]:
        """Return a shallow copy of the backend -> provider mapping."""
        return self._providers.copy()