| |
| |
| from __future__ import annotations |
|
|
| import asyncio |
| import contextlib |
| import os |
| import pathlib |
| import re |
| import shutil |
| import time |
| import traceback |
| import uuid |
| from collections import deque |
| from contextlib import asynccontextmanager |
| from functools import partial |
| from typing import Dict |
|
|
| import fire |
| import tenacity |
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from fastapi.staticfiles import StaticFiles |
| from loguru import logger |
| from metagpt.config import CONFIG |
| from metagpt.logs import set_llm_stream_logfunc |
| from metagpt.schema import Message |
| from metagpt.utils.common import any_to_name, any_to_str |
| from openai import OpenAI |
|
|
| from data_model import ( |
| LLMAPIkeyTest, |
| MessageJsonModel, |
| NewMsg, |
| Sentence, |
| Sentences, |
| SentenceType, |
| SentenceValue, |
| ThinkActPrompt, |
| ThinkActStep, |
| ) |
| from message_enum import MessageStatus, QueryAnswerType |
| from software_company import RoleRun, SoftwareCompany |
|
|
|
|
class Service:
    """Drive a ``SoftwareCompany`` for one chat request and stream its progress as text chunks."""

    @classmethod
    async def create_message(cls, req_model: NewMsg, request: Request):
        """
        Session message stream.

        Async generator used as the body of a streaming response. For every think/act
        round of a ``SoftwareCompany`` seeded with ``req_model.query`` it yields
        ``ThinkActPrompt.prompt`` strings (each followed by a blank line), and finally the
        accumulated ``MessageJsonModel`` answer.

        ``req_model.config`` (keys upper-cased, minus those listed in
        ``SERVER_METAGPT_CONFIG_EXCLUDE``) is applied to ``CONFIG`` together with a fresh
        workspace path via ``_set_context``; that workspace is removed when the stream ends.
        ``tenacity.RetryError`` and other exceptions are reported in-band as a failed step.
        Cancellation stops the in-flight ``role.act()`` task and is then re-raised.
        """
        tc_id = 0
        task = None  # in-flight role.act() task of the current round
        watcher = None  # background task that cancels ``task`` once the client disconnects
        context_applied = False  # True once this request's CONFIG context/workspace is set
        try:
            exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
            config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
            cls._set_context(config)
            context_applied = True
            msg_queue = deque()
            # Streamed LLM chunks are pushed on the left and popped from the right (FIFO);
            # falsy chunks are dropped.
            CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

            role = SoftwareCompany()
            role.recv(message=Message(content=req_model.query))
            # The final answer starts with the user's query as an already-completed text step.
            answer = MessageJsonModel(
                steps=[
                    Sentences(
                        contents=[
                            Sentence(
                                type=SentenceType.TEXT.value,
                                value=SentenceValue(answer=req_model.query),
                                is_finished=True,
                            ).model_dump()
                        ],
                        status=MessageStatus.COMPLETE.value,
                    )
                ],
                qa_type=QueryAnswerType.Answer.value,
            )

            async def stop_if_disconnect():
                # Poll once a second; after the client disconnects, cancel the running act().
                while not await request.is_disconnected():
                    await asyncio.sleep(1)

                if task is None:
                    return

                if not task.done():
                    task.cancel()
                    logger.info(f"cancel task {task}")

            # Keep a reference: the event loop only holds weak references to tasks, and the
            # watcher must be cancelled in ``finally`` rather than left polling forever.
            watcher = asyncio.create_task(stop_if_disconnect())

            while True:
                tc_id += 1
                if await request.is_disconnected():
                    return
                think_result: RoleRun = await role.think()
                if not think_result:
                    break

                think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
                think_act_prompt.update_think(tc_id, think_result)
                yield think_act_prompt.prompt + "\n\n"
                task = asyncio.create_task(role.act())

                # Relay streamed LLM output to the client while act() is running.
                while not await request.is_disconnected():
                    if msg_queue:
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue

                    if task.done():
                        break

                    await asyncio.sleep(0.5)
                else:
                    # Client disconnected while act() was running: abandon the round.
                    task.cancel()
                    return

                if task.cancelled():
                    # act() was cancelled by the disconnect watcher; the client is gone.
                    return
                act_result = task.result()  # re-raises any exception raised by act()
                think_act_prompt.update_act(act_result)
                yield think_act_prompt.prompt + "\n\n"
                answer.add_think_act(think_act_prompt)
            yield answer.prompt + "\n\n"
        except asyncio.CancelledError:
            # This generator's own task is being cancelled: stop the in-flight act() if one
            # was started (``task`` is None before the first round), then propagate.
            if task is not None:
                task.cancel()
            raise
        except tenacity.RetryError as retry_error:
            yield cls.handle_retry_error(tc_id, retry_error)
        except Exception as ex:
            description = str(ex)
            answer = traceback.format_exc()
            think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
            yield think_act_prompt.prompt + "\n\n"
        finally:
            if watcher is not None:
                watcher.cancel()
            # Only clean up when this request's workspace was configured; otherwise the
            # lookup may fail (masking the original error) or point at a foreign directory.
            if context_applied:
                workspace = CONFIG.WORKSPACE_PATH
                if workspace is not None and pathlib.Path(workspace).exists():
                    shutil.rmtree(workspace)

    @staticmethod
    def create_error_think_act_prompt(tc_id: int, title, description: str, answer: str) -> ThinkActPrompt:
        """
        Build a ``ThinkActPrompt`` describing a failed step.

        Args:
            tc_id: Round number used as the step id.
            title: Step title.
            description: Step description.
            answer: Error text shown as the step's single ERROR sentence.
        """
        step = ThinkActStep(
            id=tc_id,
            status="failed",
            title=title,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id=1, value=SentenceValue(answer=answer), is_finished=True),
        )
        return ThinkActPrompt(step=step)

    @classmethod
    def handle_retry_error(cls, tc_id: int, error: tenacity.RetryError):
        """
        Render a ``tenacity.RetryError`` as a failed-step prompt chunk.

        Nested ``RetryError``s are unwrapped to the innermost exception, whose name from
        ``any_to_str`` (presumably the module-qualified class name — verify) selects the
        formatter: ``openai.*``, ``httpx.*`` or ``json.*``. Anything else, or a failure
        while formatting, falls back to ``handle_unexpected_error``.
        """
        try:
            original_exception = error.last_attempt.exception()
            while isinstance(original_exception, tenacity.RetryError):
                original_exception = original_exception.last_attempt.exception()

            name = any_to_str(original_exception)
            if re.match(r"^openai\.", name):
                return cls._handle_openai_error(tc_id, original_exception)

            if re.match(r"^httpx\.", name):
                return cls._handle_httpx_error(tc_id, original_exception)

            if re.match(r"^json\.", name):
                return cls._handle_json_error(tc_id, original_exception)

            return cls.handle_unexpected_error(tc_id, error)
        except Exception:
            return cls.handle_unexpected_error(tc_id, error)

    @classmethod
    def _handle_openai_error(cls, tc_id, original_exception):
        """Format an OpenAI client exception using its ``message`` attribute."""
        answer = original_exception.message
        title = f"OpenAI {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_httpx_error(cls, tc_id, original_exception):
        """Format an httpx exception, including the request it refers to."""
        answer = f"{original_exception}. {original_exception.request}"
        title = f"httpx {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_json_error(cls, tc_id, original_exception):
        """Format a JSON error as a failure to parse the LLM response."""
        answer = str(original_exception)
        title = "MetaGPT Action Node Error"
        description = f"LLM response parse error. {any_to_str(original_exception)}: {original_exception}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def handle_unexpected_error(cls, tc_id, error):
        """Format ``error`` with the traceback of the exception currently being handled."""
        description = str(error)
        answer = traceback.format_exc()
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @staticmethod
    def _set_context(context: Dict) -> Dict:
        """
        Apply a per-request configuration to ``CONFIG``.

        Adds a fresh ``WORKSPACE_PATH`` (``workspace/<uuid4 hex>``), copies the legacy keys
        ``DEPLOYMENT_ID``/``OPENAI_API_BASE`` to ``DEPLOYMENT_NAME``/``OPENAI_BASE_URL`` when
        the newer key is absent, then calls ``CONFIG.set_context``.

        Returns:
            The same ``context`` dict, mutated in place.
        """
        uid = uuid.uuid4().hex
        context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
        for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
            if old in context and new not in context:
                context[new] = context[old]
        CONFIG.set_context(context)
        return context
|
|
|
|
# Fallback sink for streamed LLM output when the current CONFIG has no LLM_STREAM_LOG
# callback: print each chunk without appending a newline.
default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    """
    Forward one streamed LLM output chunk to the active sink.

    The sink is ``CONFIG.LLM_STREAM_LOG`` when set (``Service.create_message`` installs a
    per-request one), otherwise ``default_llm_stream_log``. Logging is best-effort: any
    exception raised while resolving or calling the sink is suppressed.
    """
    # contextlib.suppress() with no arguments suppresses nothing; the exception type must
    # be named for the best-effort behaviour to actually take effect.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
|
|
|
|
class ChatHandler:
    """FastAPI endpoint wrapper for chat messages."""

    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Stream the reply to a chat message as Server-Sent Events (text/event-stream)."""
        sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        # Service.create_message returns an async generator; StreamingResponse iterates
        # it and sends each yielded chunk to the client as it is produced.
        stream = Service.create_message(req_model, request)
        return StreamingResponse(stream, media_type="text/event-stream", headers=sse_headers)
|
|
|
|
class LLMAPIHandler:
    """FastAPI endpoint wrapper for validating OpenAI API keys."""

    @staticmethod
    async def check_openai_key(req_model: LLMAPIkeyTest):
        """
        Check that ``req_model.api_key`` can see the model ``req_model.llm_type``.

        Returns a ``JSONResponse`` of ``{"valid": True}`` when the model id is among the
        models listed for the key, otherwise ``{"valid": False, "message": ...}``. Any
        error (bad key, network failure, ...) is also reported as invalid, with its text
        as the message.
        """
        try:
            loop = asyncio.get_running_loop()
            # Closing the sync client on exit releases its HTTP connection pool per request.
            with OpenAI(api_key=req_model.api_key) as client:
                # models.list() is a blocking HTTP request; run it in the default executor
                # so this async endpoint does not stall the event loop for other requests.
                response = await loop.run_in_executor(None, client.models.list)
            model_set = {model.id for model in response.data}
            if req_model.llm_type in model_set:
                logger.info("API Key is valid.")
                return JSONResponse({"valid": True})
            else:
                logger.info("API Key is invalid.")
                return JSONResponse({"valid": False, "message": "Model not found"})
        except Exception as e:
            logger.info(f"Error: {e}")
            return JSONResponse({"valid": False, "message": str(e)})
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook: run ``clear_storage`` in the background while the app is up.

    The task handle is held for the app's lifetime (the event loop keeps only weak
    references to tasks, so an unreferenced task may be garbage-collected mid-flight) and
    the task is cancelled at shutdown instead of being abandoned.
    """
    loop = asyncio.get_running_loop()
    cleaner = loop.create_task(clear_storage())
    try:
        yield
    finally:
        cleaner.cancel()
|
|
|
|
# ASGI application; the lifespan hook runs the background storage cleaner.
app = FastAPI(lifespan=lifespan)

# Generated files, served read-only.
app.mount("/storage", StaticFiles(directory="./storage/"), name="storage")

# API endpoints. They must be registered before the catch-all "/" mount below, because
# Starlette matches routes in registration order.
app.add_api_route(
    "/api/messages",
    ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)
app.add_api_route(
    "/api/test-api-key",
    LLMAPIHandler.check_openai_key,
    methods=["post"],
    summary="LLM APIkey detection",
)

# Front-end bundle; html=True makes StaticFiles serve index.html for directory requests.
app.mount("/", StaticFiles(directory="./static/", html=True, follow_symlink=True), name="static")

# Route streamed LLM output through llm_stream_log (per-request sink selection).
set_llm_stream_logfunc(llm_stream_log)
|
|
|
|
def gen_file_modified_time(folder_path):
    """
    Lazily yield modification times for ``folder_path`` and every file beneath it.

    The first value is the mtime of ``folder_path`` itself. It is followed by the mtime
    of each regular file found by a top-down ``os.walk``; subdirectory mtimes are not
    included. For a non-directory path only its own mtime is produced.
    """
    yield os.path.getmtime(folder_path)
    yield from (
        os.path.getmtime(os.path.join(dirpath, filename))
        for dirpath, _subdirs, filenames in os.walk(folder_path)
        for filename in filenames
    )
|
|
|
|
async def clear_storage(ttl: float = 1800):
    """
    Background sweeper that deletes stale entries from the storage root.

    Every 60 seconds, each top-level entry of ``CONFIG.LOCAL_ROOT`` (default ``"storage"``)
    is examined. If the newest mtime among the entry itself and all files beneath it is
    more than ``ttl`` seconds old, the entry is deleted. Directories are removed
    recursively; plain files and symlinks are unlinked. Errors are logged and the sweep
    continues; the coroutine never returns.

    Args:
        ttl: Age in seconds after which an untouched entry is removed (default 30 minutes).
    """
    storage = pathlib.Path(CONFIG.get("LOCAL_ROOT", "storage"))
    logger.info("task `clear_storage` start running")

    while True:
        current_time = time.time()
        # Listing errors must not escape: this runs as a background task, so an uncaught
        # exception would silently stop every future sweep.
        try:
            names = os.listdir(storage)
        except FileNotFoundError:
            names = []  # storage root does not exist (yet); nothing to clean
        except OSError:
            logger.exception(f"list {storage} error")
            names = []
        for name in names:
            entry = storage / name
            try:
                last_time = max(gen_file_modified_time(entry))
                if current_time - last_time > ttl:
                    # rmtree only accepts real directories; files and symlinks (even to
                    # directories) must be unlinked instead.
                    if entry.is_dir() and not entry.is_symlink():
                        shutil.rmtree(entry)
                        kind = "directory"
                    else:
                        entry.unlink()
                        kind = "file"
                    await asyncio.sleep(0)  # let other coroutines run between deletions
                    logger.info(f"Deleted {kind}: {entry}")
            except Exception:
                logger.exception(f"check {entry} error")
        await asyncio.sleep(60)
|
|
|
|
def main():
    """
    Launch the web server with uvicorn.

    Extra keyword options for ``uvicorn.run`` are read from the ``SERVER_UVICORN`` config
    key (an empty dict when unset). The app is given as the import string
    ``"__main__:app"``. uvicorn requires an import string for options such as ``reload``
    or ``workers``, which is presumably why it is used here.
    """
    uvicorn_options = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run("__main__:app", **uvicorn_options)


# Expose main() as the command-line entry point via python-fire.
if __name__ == "__main__":
    fire.Fire(main)
|
|