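"""Web server that streams a MetaGPT software-company run to the browser.

Exposes a FastAPI app with one SSE endpoint (/api/messages) plus static mounts
for the generated storage and the front-end assets.
"""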
from __future__ import annotations

import asyncio
import contextlib
import pathlib
import shutil
import traceback
import uuid
from collections import deque
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Optional

import fire
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.config import CONFIG
from metagpt.logs import set_llm_stream_logfunc
from metagpt.schema import Message
from pydantic import BaseModel, Field

from software_company import RoleRun, SoftwareCompany


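# Whether a streamed record is the user's query ("Q") or the server's answer ("A").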
class QueryAnswerType(Enum):
    Query = "Q"
    Answer = "A"


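# Kinds of sentence payloads the front end can render.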
class SentenceType(Enum):
    TEXT = "text"
    HINT = "hint"
    ACTION = "action"
    ERROR = "error"


class MessageStatus(Enum):
    COMPLETE = "complete"


class SentenceValue(BaseModel):
    answer: str


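# A single renderable unit in the stream; `type` takes SentenceType values.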
class Sentence(BaseModel):
    type: str
    id: Optional[str] = None
    value: SentenceValue
    is_finished: Optional[bool] = None


class Sentences(BaseModel):
    id: Optional[str] = None
    action: Optional[str] = None
    role: Optional[str] = None
    skill: Optional[str] = None
    description: Optional[str] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    status: str
    contents: list[dict]


class NewMsg(BaseModel):
    """Chat with MetaGPT"""

    query: str = Field(description="Problem description")
    config: dict[str, Any] = Field(description="Configuration information")


class ErrorInfo(BaseModel):
    error: Optional[str] = None
    traceback: Optional[str] = None


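# One think/act step in the streamed timeline.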
class ThinkActStep(BaseModel):
    id: str
    status: str
    title: str
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    description: str
    content: Optional[Sentence] = None


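# Wire format for a single think/act update pushed over SSE.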
class ThinkActPrompt(BaseModel):
    message_id: Optional[int] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    step: Optional[ThinkActStep] = None
    skill: Optional[str] = None
    role: Optional[str] = None

    def update_think(self, tc_id, action: Action):
        self.step = ThinkActStep(
            id=str(tc_id),
            status="running",
            title=action.desc,
            description=action.desc,
        )

    def update_act(self, message: ActionOutput | str, is_finished: bool = True):
        # While streaming, `message` is a raw token chunk (str); on completion it
        # is the finished ActionOutput, whose .content holds the full answer.
        if is_finished:
            self.step.status = "finish"
        self.step.content = Sentence(
            type=SentenceType.TEXT.value,
            id="1",
            value=SentenceValue(answer=message.content if is_finished else message),
            is_finished=is_finished,
        )

    @staticmethod
    def guid32():
        return str(uuid.uuid4()).replace("-", "")[0:32]

    @property
    def prompt(self):
        return self.json(exclude_unset=True)


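# Cumulative record of the whole session: the original query plus every completed step.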
class MessageJsonModel(BaseModel):
    steps: list[Sentences]
    qa_type: str
    created_at: datetime = Field(default_factory=datetime.now)
    query_time: datetime = Field(default_factory=datetime.now)
    answer_time: datetime = Field(default_factory=datetime.now)
    score: Optional[int] = None
    feedback: Optional[str] = None

    def add_think_act(self, think_act_prompt: ThinkActPrompt):
        s = Sentences(
            action=think_act_prompt.step.title,
            skill=think_act_prompt.skill,
            description=think_act_prompt.step.description,
            timestamp=think_act_prompt.timestamp,
            status=think_act_prompt.step.status,
            contents=[think_act_prompt.step.content.dict()],
        )
        self.steps.append(s)

    @property
    def prompt(self):
        return self.json(exclude_unset=True)


async def create_message(req_model: NewMsg, request: Request):
    """Session message stream, delivered to the client as Server-Sent Events."""
    tc_id = 0
    task = None  # bound before `try` so both exception handlers can reference it safely
    try:
        exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
        config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
        set_context(config, uuid.uuid4().hex)

        msg_queue = deque()
        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

        role = SoftwareCompany()
        role.recv(message=Message(content=req_model.query))
        answer = MessageJsonModel(
            steps=[
                Sentences(
                    contents=[
                        Sentence(
                            type=SentenceType.TEXT.value, value=SentenceValue(answer=req_model.query), is_finished=True
                        )
                    ],
                    status=MessageStatus.COMPLETE.value,
                )
            ],
            qa_type=QueryAnswerType.Answer.value,
        )

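        # Watchdog: poll for client disconnects and cancel any in-flight act() task.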
        async def stop_if_disconnect():
            while not await request.is_disconnected():
                await asyncio.sleep(1)
            if task is None:
                return
            if not task.done():
                task.cancel()
                logger.info(f"cancel task {task}")

        # Keep a reference so the watchdog is not garbage-collected while running.
        watchdog_task = asyncio.create_task(stop_if_disconnect())  # noqa: F841

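        # Main loop: one role.think() per step; stream partial act() output as it arrives.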
        while True:
            tc_id += 1
            if await request.is_disconnected():
                return
            think_result: RoleRun = await role.think()
            if not think_result:  # no runnable step left: the run is complete
                break

            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
            think_act_prompt.update_think(tc_id, think_result)
            yield think_act_prompt.prompt + "\n\n"
            task = asyncio.create_task(role.act())

            while not await request.is_disconnected():
                if msg_queue:
                    think_act_prompt.update_act(msg_queue.pop(), False)
                    yield think_act_prompt.prompt + "\n\n"
                    continue
                if task.done():
                    break
                await asyncio.sleep(0.5)
            else:  # the inner loop ended because the client disconnected
                task.cancel()
                return

            act_result = await task
            think_act_prompt.update_act(act_result)
            yield think_act_prompt.prompt + "\n\n"
            answer.add_think_act(think_act_prompt)
        yield answer.prompt + "\n\n"  # final cumulative answer record
    except asyncio.CancelledError:
        if task is not None:
            task.cancel()
    except Exception as ex:
        description = str(ex)
        tb = traceback.format_exc()
        step = ThinkActStep(
            id=str(tc_id),
            status="failed",
            title=description,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id="1", value=SentenceValue(answer=tb), is_finished=True),
        )
        think_act_prompt = ThinkActPrompt(step=step)
        yield think_act_prompt.prompt + "\n\n"
    finally:
        CONFIG.WORKSPACE_PATH: pathlib.Path  # annotation only; hints the type for readers
        if CONFIG.WORKSPACE_PATH.exists():
            shutil.rmtree(CONFIG.WORKSPACE_PATH)  # each session gets a throwaway workspace


default_llm_stream_log = partial(print, end="")


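# Best-effort dispatch of streamed LLM tokens to the handler registered in CONFIG.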
def llm_stream_log(msg):
    # `contextlib.suppress()` with no arguments suppresses nothing; pass Exception
    # so a broken log hook cannot take down the stream.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)


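# Build the per-session CONFIG context: an isolated workspace plus aliases for legacy keys.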
def set_context(context, uid):
    context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
    for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
        if old in context and new not in context:
            context[new] = context[old]
    CONFIG.set_context(context)
    return context


class ChatHandler:
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        event = create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")


app = FastAPI()

app.mount(
    "/storage",
    StaticFiles(directory="./storage/"),
    name="storage",
)

app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["POST"],
    summary="Session message sending (streaming response)",
)

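# Example request, assuming the default uvicorn port and illustrative config keys
# (the exact config schema depends on your MetaGPT deployment):
#   curl -N -X POST http://localhost:8000/api/messages \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Write a CLI snake game", "config": {"OPENAI_API_KEY": "sk-..."}}'

# The catch-all "/" mount is registered last so it cannot shadow the API route
# above; Starlette matches routes in registration order.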
app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

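# Register the stream hook so LLM token callbacks flow through llm_stream_log.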
set_llm_stream_logfunc(llm_stream_log)


def main():
    server_config = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **server_config)


if __name__ == "__main__":
    fire.Fire(main)