| | import json |
| | import os |
| | import subprocess |
| | import sys |
| | import time |
| |
|
| | import aiohttp |
| | import requests |
| |
|
| | from lagent.schema import AgentMessage |
| |
|
| |
|
class HTTPAgentClient:
    """Blocking HTTP client for an agent exposed by a remote server.

    All requests go to ``http://{host}:{port}`` and use ``timeout`` as the
    ``requests`` timeout (``None`` means no timeout is passed on).

    Args:
        host: Host name or IP address of the server.
        port: TCP port of the server.
        timeout: Value forwarded as ``timeout=`` to every ``requests`` call.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    @property
    def is_alive(self):
        """Return ``True`` when ``/health_check`` answers with HTTP 200.

        Transport-level failures (connection refused, timeout, ...) are
        reported as ``False``.  Only ``requests`` errors are swallowed so
        that ``KeyboardInterrupt``/``SystemExit`` and genuine programming
        errors still propagate.
        """
        try:
            resp = requests.get(
                f'http://{self.host}:{self.port}/health_check',
                timeout=self.timeout)
        except requests.RequestException:
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST the messages to ``/chat_completion`` and return the reply.

        Args:
            *message: Messages to send.  Strings are sent unchanged; any
                other object is serialised through its ``model_dump()``.
            session_id: Session identifier included in the request body.
            **kwargs: Extra entries merged into the JSON body.  They are
                merged last, so they override ``message``/``session_id``
                when the same key is given.

        Returns:
            ``AgentMessage.model_validate(body)`` when the status code is
            200, otherwise the decoded JSON body as returned by the server.
        """
        response = requests.post(
            f'http://{self.host}:{self.port}/chat_completion',
            json={
                'message': [
                    m if isinstance(m, str) else m.model_dump()
                    for m in message
                ],
                'session_id': session_id,
                **kwargs,
            },
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        # The body is decoded before the status check, so a non-JSON error
        # response raises from ``response.json()``.
        resp = response.json()
        if response.status_code != 200:
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Return the decoded JSON body of ``GET /memory/{session_id}``."""
        resp = requests.get(
            f'http://{self.host}:{self.port}/memory/{session_id}',
            timeout=self.timeout)
        return resp.json()
| |
|
| |
|
class HTTPAgentServer(HTTPAgentClient):
    """Launch the HTTP agent app in a subprocess and act as its client.

    Constructing an instance starts the server immediately (see
    :meth:`start_server`) and blocks until it reports readiness or its
    output stream ends.

    Args:
        gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` in the child process.
            A string is used as is; an ``int`` or a list/tuple of ids is
            converted to the comma-separated string form.
        config: JSON-serialisable object passed to the app as ``--config``.
        host: Host passed to the app as ``--host`` and used for requests.
        port: Port passed to the app as ``--port`` and used for requests.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
        super().__init__(host, port)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        """Spawn the server process and wait for its startup banner.

        The child's combined stdout/stderr is echoed to ``sys.stdout``
        line by line until a line containing ``'Uvicorn running on'`` is
        seen or the stream reaches EOF.
        """
        env = os.environ.copy()
        # Environment values must be ``str``; accept the common int and
        # list/tuple spellings of a device selection as well.
        gpu_id = self.gpu_id
        if isinstance(gpu_id, (list, tuple)):
            gpu_id = ','.join(str(g) for g in gpu_id)
        elif isinstance(gpu_id, int):
            gpu_id = str(gpu_id)
        env['CUDA_VISIBLE_DEVICES'] = gpu_id
        # NOTE(review): the script path is relative, so it is resolved
        # against the current working directory of the caller.
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py', '--host',
            self.host, '--port',
            str(self.port), '--config',
            json.dumps(self.config)
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)

        # ``readline`` blocks until a full line (or EOF, returned as '')
        # is available, so no extra sleep is needed between lines.
        for output in iter(self.process.stdout.readline, ''):
            sys.stdout.write(output)
            sys.stdout.flush()
            if 'Uvicorn running on' in output:
                break
        # NOTE(review): EOF before the banner (child exited early) is not
        # reported here, and the pipe is no longer read after startup, so
        # a very chatty child could eventually block on a full pipe.

    def shutdown(self):
        """Terminate the server process, wait for it and close its pipe."""
        self.process.terminate()
        self.process.wait()
        if self.process.stdout is not None:
            self.process.stdout.close()
| |
|
| |
|
class AsyncHTTPAgentMixin:
    """Mixin providing an ``aiohttp``-based, awaitable ``__call__``.

    The host class must supply ``host``, ``port`` and ``timeout``
    attributes; ``timeout`` is handed to ``aiohttp.ClientTimeout``.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST the messages to ``/chat_completion`` without blocking.

        Strings are sent unchanged, other messages through their
        ``model_dump()``.  ``kwargs`` are merged last into the JSON body.

        Returns:
            ``AgentMessage.model_validate(body)`` for HTTP 200, otherwise
            the decoded JSON body.
        """
        client_timeout = aiohttp.ClientTimeout(self.timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as http:
            endpoint = f'http://{self.host}:{self.port}/chat_completion'

            # Serialise each message in the order it was given.
            serialized = []
            for item in message:
                if isinstance(item, str):
                    serialized.append(item)
                else:
                    serialized.append(item.model_dump())

            payload = {
                'message': serialized,
                'session_id': session_id,
                **kwargs,
            }
            request = http.post(
                endpoint,
                json=payload,
                headers={'Content-Type': 'application/json'},
            )
            async with request as reply:
                body = await reply.json()
                if reply.status == 200:
                    return AgentMessage.model_validate(body)
                return body
| |
|
| |
|
class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """``HTTPAgentClient`` whose ``__call__`` is the mixin's coroutine.

    The mixin comes first in the bases, so its ``__call__`` takes
    precedence; every other member is inherited from ``HTTPAgentClient``.
    """
| |
|
| |
|
class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """``HTTPAgentServer`` whose ``__call__`` is the mixin's coroutine.

    The mixin comes first in the bases, so its ``__call__`` takes
    precedence; every other member is inherited from ``HTTPAgentServer``.
    """
| |
|