| import argparse | |
| import json | |
| import logging | |
| import time | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.requests import Request | |
| from lagent.schema import AgentMessage | |
| from lagent.utils import load_class_from_string | |
class AgentAPIServer:
    """Serve an agent over HTTP using FastAPI and uvicorn.

    The agent class comes from ``config['type']``. It is either a class object
    or a string resolved with ``load_class_from_string`` (together with the
    optional ``config['python_path']``). The class is instantiated with the
    remaining config entries as keyword arguments.

    Constructing an instance starts the server right away: ``__init__`` ends by
    calling :meth:`run`, which hands control to ``uvicorn.run``.

    Routes:
        GET  /health_check        -- status plus the current server timestamp.
        POST /chat_completion     -- forward messages to the agent.
        GET  /memory/{session_id} -- return ``agent.state_dict(session_id)``.
    """

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        """Build the app, instantiate the agent and start serving.

        Args:
            config: Agent configuration. ``'type'`` is required and
                ``'python_path'`` is optional. Every other key is passed to the
                agent constructor. The caller's dict is not modified.
            host: Host passed to ``uvicorn.run``.
            port: Port passed to ``uvicorn.run``.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Work on a shallow copy. Popping keys from the caller's dict would make
        # it fail with KeyError('type') if it were reused.
        config = dict(config)
        self.app = FastAPI(docs_url='/')
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        cls_name = config.pop('type')
        python_path = config.pop('python_path', None)
        # 'type' may already be a class object; only strings need resolving.
        cls_name = load_class_from_string(cls_name, python_path) if isinstance(
            cls_name, str) else cls_name
        self.agent = cls_name(**config)
        self.setup_routes()
        self.run(host, port)

    def setup_routes(self):
        """Register the health-check, chat and memory endpoints on ``self.app``."""

        def heartbeat():
            """Report liveness together with the server's current epoch time."""
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            """Forward the request's messages to the agent and return its reply.

            The body must be a JSON object with a ``'message'`` entry: a list
            whose items are plain strings or dicts validated as
            ``AgentMessage``. A single string or dict is treated as a
            one-element list. All other keys are passed to the agent as keyword
            arguments.

            Raises:
                HTTPException: 400 for a malformed request body, 500 if the
                    agent call fails.
            """
            try:
                body = await request.json()
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                raise HTTPException(
                    status_code=400,
                    detail='Request body must be valid JSON') from e
            if not isinstance(body, dict) or 'message' not in body:
                raise HTTPException(
                    status_code=400,
                    detail="Request body must be a JSON object with a "
                    "'message' field")
            raw_messages = body.pop('message')
            # Iterating a bare string (or dict) would split it into characters
            # (or keys), so wrap a single message in a list first.
            if isinstance(raw_messages, (str, dict)):
                raw_messages = [raw_messages]
            try:
                message = [
                    m if isinstance(m, str) else AgentMessage.model_validate(m)
                    for m in raw_messages
                ]
            except (TypeError, ValueError) as e:
                # TypeError: 'message' is not iterable. ValueError: pydantic's
                # ValidationError subclasses it.
                raise HTTPException(
                    status_code=400, detail=f'Invalid message: {e}') from e
            try:
                return await self.agent(*message, **body)
            except Exception as e:
                logging.exception('Error processing message')
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        def get_memory(session_id: int = 0):
            """Return the agent's ``state_dict`` for ``session_id``.

            Raises:
                HTTPException: 404 if ``state_dict`` raises ``KeyError``, and
                    500 for any other failure.
            """
            try:
                return self.agent.state_dict(session_id)
            except KeyError as e:
                raise HTTPException(
                    status_code=404, detail="Session ID not found") from e
            except Exception as e:
                logging.exception(
                    'Error retrieving memory for session %s', session_id)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host='127.0.0.1', port=8090):
        """Log the bind address and serve ``self.app`` with ``uvicorn.run``."""
        logging.info(f'Starting server at {host}:{port}')
        uvicorn.run(self.app, host=host, port=port)
def _json_object(text):
    """Parse ``--config`` text as a JSON object that contains a ``'type'`` key.

    Raising ``argparse.ArgumentTypeError`` makes argparse show this message
    rather than its generic "invalid <type> value" error.
    """
    try:
        value = json.loads(text)
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        raise argparse.ArgumentTypeError(f'invalid JSON: {e}') from e
    if not isinstance(value, dict):
        raise argparse.ArgumentTypeError('must be a JSON object')
    if 'type' not in value:
        # The server resolves the agent class from config['type'].
        raise argparse.ArgumentTypeError("must contain a 'type' key")
    return value


def parse_args(argv=None):
    """Parse command-line options for the agent API server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``host`` (str), ``port`` (int) and ``config``
        (dict decoded from the ``--config`` JSON string).
    """
    parser = argparse.ArgumentParser(description='Async Agent API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8090)
    parser.add_argument(
        '--config',
        type=_json_object,
        required=True,
        help="JSON configuration for the agent (an object with a 'type' key)")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Enable INFO-level logging first so the server's startup line is shown.
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    # The constructor builds the app and starts serving immediately.
    AgentAPIServer(
        cli_args.config,
        host=cli_args.host,
        port=cli_args.port,
    )