| import argparse |
| import json |
| import logging |
| import time |
|
|
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.requests import Request |
|
|
| from lagent.schema import AgentMessage |
| from lagent.utils import load_class_from_string |
|
|
|
|
class AgentAPIServer:
    """Expose a single agent over HTTP using FastAPI and uvicorn.

    Routes registered by :meth:`setup_routes`:

    * ``GET /health_check`` -- liveness probe.
    * ``POST /chat_completion`` -- forward messages to the agent.
    * ``GET /memory/{session_id}`` -- return ``agent.state_dict(session_id)``.

    Note:
        The constructor ends by calling :meth:`run`, which hands control to
        ``uvicorn.run``; creating an instance therefore starts the server.
    """

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        """Build the agent from ``config``, register routes and start serving.

        Args:
            config: Agent description. ``config['type']`` is either the agent
                class itself or a string resolved with
                ``load_class_from_string``; the optional ``'python_path'``
                entry is forwarded to that loader. All remaining entries are
                passed to the agent class as keyword arguments.
            host: Interface the server binds to.
            port: TCP port the server listens on.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        self.app = FastAPI(docs_url='/')
        # NOTE(review): every origin/method/header is allowed and credentials
        # are enabled; confirm this is acceptable outside local development.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        # Work on a shallow copy so the caller's dict is not stripped of its
        # 'type' / 'python_path' entries as a side effect of construction.
        agent_kwargs = dict(config)
        agent_cls = agent_kwargs.pop('type')
        python_path = agent_kwargs.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        self.agent = agent_cls(**agent_kwargs)
        self.setup_routes()
        self.run(host, port)

    def setup_routes(self):
        """Register the HTTP endpoints on ``self.app``."""

        def heartbeat():
            """Report liveness together with the current server time."""
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            """Forward the request's messages to the agent.

            The JSON body must be an object with a ``'message'`` entry: a
            list whose items are plain strings or mappings validated into
            ``AgentMessage``. A single string or mapping is treated as a
            one-element list. Every other body entry is passed to the agent
            as a keyword argument. The agent call is awaited.

            Raises:
                HTTPException: 400 for a malformed request, 500 when message
                    validation or the agent call fails.
            """
            try:
                body = await request.json()
            except ValueError as e:  # json.JSONDecodeError / decode errors
                raise HTTPException(
                    status_code=400,
                    detail='Request body is not valid JSON') from e
            if not isinstance(body, dict) or 'message' not in body:
                raise HTTPException(
                    status_code=400,
                    detail="Request body must be a JSON object with a "
                    "'message' field")

            raw_messages = body.pop('message')
            # Iterating a bare str/dict would yield characters/keys and feed
            # them to the agent as separate messages.
            if isinstance(raw_messages, (str, dict)):
                raw_messages = [raw_messages]

            try:
                message = [
                    m if isinstance(m, str) else AgentMessage.model_validate(m)
                    for m in raw_messages
                ]
                return await self.agent(*message, **body)
            except Exception as e:
                # logging.exception keeps the traceback in the server log
                # while the client only sees a generic 500.
                logging.exception('Error processing message: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        def get_memory(session_id: int = 0):
            """Return ``self.agent.state_dict(session_id)``.

            Raises:
                HTTPException: 404 when the lookup raises ``KeyError``,
                    500 for any other failure.
            """
            try:
                return self.agent.state_dict(session_id)
            except KeyError:
                raise HTTPException(
                    status_code=404, detail="Session ID not found")
            except Exception as e:
                logging.exception(
                    'Error retrieving memory for session %s: %s',
                    session_id, e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host: str = '127.0.0.1', port: int = 8090):
        """Log the bind address and serve ``self.app`` with ``uvicorn.run``."""
        logging.info(f'Starting server at {host}:{port}')
        uvicorn.run(self.app, host=host, port=port)
|
|
|
|
def _json_object(text):
    """argparse ``type=`` callable: parse ``text`` as a JSON object.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not valid JSON or the
            decoded value is not a JSON object (``dict``).
    """
    try:
        value = json.loads(text)
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        raise argparse.ArgumentTypeError(f'invalid JSON: {e}') from e
    if not isinstance(value, dict):
        raise argparse.ArgumentTypeError(
            f'expected a JSON object, got {type(value).__name__}')
    return value


def parse_args(argv=None):
    """Parse the command-line options of the agent API server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace: With ``host`` (str), ``port`` (int) and
        ``config`` (dict decoded from the ``--config`` JSON string).

    Raises:
        SystemExit: Raised by argparse when ``--config`` is missing, is not
            valid JSON, or does not decode to a JSON object.
    """
    parser = argparse.ArgumentParser(description='Async Agent API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8090)
    parser.add_argument(
        '--config',
        # Reject non-object JSON here, with a clear usage error, instead of
        # letting a list/str reach code that expects a dict.
        type=_json_object,
        required=True,
        help='JSON configuration for the agent')
    return parser.parse_args(argv)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: enable INFO-level logging, read the command-line
    # options, then build the server from them. AgentAPIServer.__init__
    # calls run() itself, so no explicit start call is needed here.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    server_options = {'host': args.host, 'port': args.port}
    AgentAPIServer(args.config, **server_options)
|
|