| import importlib |
| import sys |
| from typing import Dict |
|
|
| import ray |
|
|
| from lagent.schema import AgentMessage |
| from lagent.utils import load_class_from_string |
|
|
|
|
class AsyncAgentRayActor:
    """Awaitable front-end for an agent instantiated as a Ray actor.

    The agent class named in ``config`` is wrapped with ``ray.remote`` and
    instantiated remotely; calling this object forwards the call to the
    remote agent's ``__call__`` and awaits the result.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor described by ``config``.

        Args:
            config: Agent configuration. ``'type'`` (required) is either the
                agent class itself or a string that is resolved through
                ``load_class_from_string``. ``'python_path'`` (optional) is
                passed to ``load_class_from_string`` alongside the string.
                All remaining entries are forwarded as keyword arguments to
                the agent's constructor. The dict is not modified.
            num_gpus: GPU resource request passed to ``ray.remote``.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop from a shallow copy: removing 'type'/'python_path' from the
        # caller's dict would make the same config unusable for building a
        # second actor.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        remote_cls = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = remote_cls.remote(**config)

    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward the call to the remote agent and await its response.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        return await self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
|
|
|
|
class AgentRayActor:
    """Blocking front-end for an agent instantiated as a Ray actor.

    The agent class named in ``config`` is wrapped with ``ray.remote`` and
    instantiated remotely; calling this object forwards the call to the
    remote agent's ``__call__`` and blocks on ``ray.get`` for the result.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        """Create the remote agent actor described by ``config``.

        Args:
            config: Agent configuration. ``'type'`` (required) is either the
                agent class itself or a string that is resolved through
                ``load_class_from_string``. ``'python_path'`` (optional) is
                passed to ``load_class_from_string`` alongside the string.
                All remaining entries are forwarded as keyword arguments to
                the agent's constructor. The dict is not modified.
            num_gpus: GPU resource request passed to ``ray.remote``.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop from a shallow copy: removing 'type'/'python_path' from the
        # caller's dict would make the same config unusable for building a
        # second actor.
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        remote_cls = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = remote_cls.remote(**config)

    def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward the call to the remote agent and wait for its response.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns, fetched with
            ``ray.get``.
        """
        pending = self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return ray.get(pending)
|
|