Spaces:
Runtime error
Runtime error
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from __future__ import annotations | |
| import asyncio | |
| from collections import deque | |
| import contextlib | |
| from functools import partial | |
| import shutil | |
| import urllib.parse | |
| from datetime import datetime | |
| import uuid | |
| from enum import Enum | |
| from metagpt.logs import set_llm_stream_logfunc | |
| import pathlib | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import StreamingResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import fire | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| from typing import Any, Optional | |
| from metagpt.schema import Message | |
| from metagpt.actions.action import Action | |
| from metagpt.actions.action_output import ActionOutput | |
| from metagpt.config import CONFIG | |
| from software_company import RoleRun, SoftwareCompany | |
class QueryAnswerType(Enum):
    """Marks whether a stored message is a query or an answer."""

    # Single-letter codes are what gets written into ``qa_type`` fields.
    Query = "Q"
    Answer = "A"
class SentenceType(Enum):
    """Kinds of content a ``Sentence`` can carry."""

    TEXT = "text"
    # ``HIHT`` is a historical misspelling of "hint". It stays the canonical
    # member so existing callers, ``.name`` values and iteration order are
    # unchanged.
    HIHT = "hint"
    ACTION = "action"
    # Correctly spelled alias: same value as ``HIHT``, so Enum treats it as an
    # alias of that member rather than as a new one.
    HINT = "hint"
class MessageStatus(Enum):
    """Lifecycle states recorded on a message step."""

    # Only one state is defined here: the step has been fully produced.
    COMPLETE = "complete"
class SentenceValue(BaseModel):
    """Payload of a ``Sentence``: the text shown to the user."""

    # Required text content.
    answer: str
class Sentence(BaseModel):
    """One piece of content inside a step."""

    # Content kind as a plain string (callers in this file pass "text").
    type: str
    # Optional identifier of this sentence within its step.
    id: Optional[str] = None
    # The actual payload.
    value: SentenceValue
    # True once the content is final, False while it is still streaming.
    is_finished: Optional[bool] = None
class Sentences(BaseModel):
    """One step of a conversation: metadata plus a list of content dicts."""

    id: Optional[str] = None
    action: Optional[str] = None
    role: Optional[str] = None
    skill: Optional[str] = None
    description: Optional[str] = None
    # Computed per instance. A plain ``= datetime.now()...`` default would be
    # evaluated once at import time and shared by every instance.
    # NOTE: ``datetime.now()`` is naive, so ``%z`` renders as an empty string.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z")
    )
    status: str
    # Each entry is a dict; ``add_think_act`` passes ``Sentence.dict()`` output.
    contents: list[dict]
class NewMsg(BaseModel):
    """Chat with MetaGPT"""

    # Both fields are required: no default is supplied to ``Field``.
    query: str = Field(description="Problem description")
    # Free-form settings; ``create_message`` upper-cases the keys before use.
    config: dict[str, Any] = Field(description="Configuration information")
class ErrorInfo(BaseModel):
    """Error details; both fields are optional and default to ``None``."""

    # Annotated Optional because the default is None: a bare ``str`` annotation
    # misstates the type and rejects an explicit None under strict validation.
    error: Optional[str] = None
    traceback: Optional[str] = None
class ThinkActStep(BaseModel):
    """State of a single think/act step as streamed to the client."""

    id: str
    # "running" while the action executes, "finish" once it has completed
    # (the values written by ``ThinkActPrompt``).
    status: str
    title: str
    timestamp: str
    description: str
    # Absent until the action produces output, hence Optional.
    content: Optional[Sentence] = None
class ThinkActPrompt(BaseModel):
    """Progress record for one think/act cycle, serialisable for streaming."""

    message_id: Optional[int] = None
    # Computed per instance rather than once at import time.
    # NOTE: ``datetime.now()`` is naive, so ``%z`` renders as an empty string.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z")
    )
    step: Optional[ThinkActStep] = None
    skill: Optional[str] = None
    role: Optional[str] = None

    def update_think(self, tc_id, action: Action):
        """Start a new step in the "running" state.

        Args:
            tc_id: Step counter; stored as a string id.
            action: Object whose ``desc`` is used as both title and description.
        """
        self.step = ThinkActStep(
            id=str(tc_id),
            status="running",
            title=action.desc,
            timestamp=datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"),
            description=action.desc,
        )

    def update_act(self, message: ActionOutput | str, is_finished: bool = True):
        """Attach output to the current step.

        Args:
            message: When ``is_finished`` is true, an object whose ``content``
                is the final text; otherwise the raw partial text itself.
            is_finished: Whether this is the final output of the step. When
                true the step status becomes "finish".
        """
        if is_finished:
            self.step.status = "finish"
        self.step.content = Sentence(
            type="text",
            id=str(1),
            value=SentenceValue(answer=message.content if is_finished else message),
            is_finished=is_finished,
        )

    @staticmethod
    def guid32():
        """Return a random 32-character hex string (a UUID4 without dashes)."""
        return str(uuid.uuid4()).replace("-", "")[0:32]

    @property
    def prompt(self):
        """URL-quoted JSON of the explicitly set fields.

        Exposed as a property because callers read it as an attribute
        (``obj.prompt + "\\n\\n"``).
        """
        v = self.json(exclude_unset=True)
        return urllib.parse.quote(v)
class MessageJsonModel(BaseModel):
    """Complete message: the ordered steps plus bookkeeping metadata."""

    steps: list[Sentences]
    qa_type: str
    # ``default_factory`` makes each instance record its own creation time; a
    # plain ``= datetime.now()`` default is evaluated once at import time.
    created_at: datetime = Field(default_factory=datetime.now)
    query_time: datetime = Field(default_factory=datetime.now)
    answer_time: datetime = Field(default_factory=datetime.now)
    score: Optional[int] = None
    feedback: Optional[str] = None

    def add_think_act(self, think_act_prompt: ThinkActPrompt):
        """Append the step held by ``think_act_prompt`` to ``steps``.

        The step's content is stored as a plain dict via ``.dict()``.
        """
        s = Sentences(
            action=think_act_prompt.step.title,
            skill=think_act_prompt.skill,
            description=think_act_prompt.step.description,
            timestamp=think_act_prompt.timestamp,
            status=think_act_prompt.step.status,
            contents=[think_act_prompt.step.content.dict()],
        )
        self.steps.append(s)

    @property
    def prompt(self):
        """URL-quoted JSON of the explicitly set fields.

        Exposed as a property because callers read it as an attribute
        (``obj.prompt + "\\n\\n"``).
        """
        v = self.json(exclude_unset=True)
        return urllib.parse.quote(v)
async def create_message(req_model: NewMsg, request: Request):
    """Session message stream.

    Async generator that drives a ``SoftwareCompany`` role through repeated
    think/act cycles for one query and yields progress as text chunks. Each
    chunk is the ``.prompt`` value of a progress record followed by a blank
    line. The last chunk is the whole accumulated message.

    Args:
        req_model: The query text plus per-request configuration. Config keys
            are upper-cased before being applied.
        request: The HTTP request, polled for client disconnects. If falsy,
            disconnect checks are skipped.

    Yields:
        str: One serialized progress update per chunk.
    """
    try:
        config = {k.upper(): v for k, v in req_model.config.items()}
        set_context(config, uuid.uuid4().hex)

        # Streamed LLM chunks are pushed on the left and popped from the right,
        # so they are forwarded in arrival order. Empty chunks are dropped.
        msg_queue = deque()
        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

        role = SoftwareCompany()
        role.recv(message=Message(content=req_model.query))

        # ``Sentences.contents`` is declared as ``list[dict]``, so the initial
        # sentence is converted with ``.dict()`` exactly as ``add_think_act``
        # does for later steps.
        first_sentence = Sentence(
            type=SentenceType.TEXT.value,
            value=SentenceValue(answer=req_model.query),
            is_finished=True,
        )
        answer = MessageJsonModel(
            steps=[
                Sentences(
                    contents=[first_sentence.dict()],
                    status=MessageStatus.COMPLETE.value,
                )
            ],
            qa_type=QueryAnswerType.Answer.value,
        )

        tc_id = 0
        while True:
            tc_id += 1
            if request and await request.is_disconnected():
                return
            think_result: RoleRun = await role.think()
            if not think_result:  # End of conversation
                break

            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
            think_act_prompt.update_think(tc_id, think_result)
            yield think_act_prompt.prompt + "\n\n"

            task = asyncio.create_task(role.act())
            try:
                while True:
                    # Same None-guard as the outer loop. A disconnected client
                    # ends the stream instead of waiting for the action.
                    if request and await request.is_disconnected():
                        return
                    if msg_queue:
                        # Drain queued chunks before checking completion so no
                        # partial output is lost once the task has finished.
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue
                    if task.done():
                        break
                    await asyncio.sleep(0.5)
                act_result = await task
            finally:
                # Reached on disconnect, on generator close and on errors:
                # do not leave the action running in the background.
                if not task.done():
                    task.cancel()

            think_act_prompt.update_act(act_result)
            yield think_act_prompt.prompt + "\n\n"
            answer.add_think_act(think_act_prompt)
        yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
    finally:
        # Best-effort cleanup: a missing or partially removed workspace must
        # not mask the exception (if any) that ended the stream.
        shutil.rmtree(CONFIG.WORKSPACE_PATH, ignore_errors=True)
# Fallback sink: print chunks as they arrive, without adding newlines.
default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    """Forward a streamed LLM chunk to the currently configured sink.

    The sink is looked up on ``CONFIG`` under ``LLM_STREAM_LOG`` and falls
    back to ``default_llm_stream_log``.

    Args:
        msg: The chunk to forward.
    """
    # Stream logging is best-effort: a failing sink must not break generation.
    # ``contextlib.suppress()`` with no exception types suppresses nothing, so
    # the exception class has to be named for the guard to take effect.
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
def set_context(context, uid):
    """Prepare a per-request config dict and install it on ``CONFIG``.

    The dict is modified in place: a workspace path derived from ``uid`` is
    added, and two legacy keys are copied to their newer names when the newer
    name is not already present.

    Args:
        context: Mutable mapping of configuration values.
        uid: Identifier used as the workspace sub-directory name.

    Returns:
        The same ``context`` object that was passed in.
    """
    legacy_to_current = {
        "DEPLOYMENT_ID": "DEPLOYMENT_NAME",
        "OPENAI_API_BASE": "OPENAI_BASE_URL",
    }
    context["WORKSPACE_PATH"] = pathlib.Path("workspace") / uid
    for legacy_key, current_key in legacy_to_current.items():
        if legacy_key not in context:
            continue
        # Only fills the newer key when it is absent; never overwrites it.
        context.setdefault(current_key, context[legacy_key])
    CONFIG.set_context(context)
    return context
class ChatHandler:
    """Namespace for the chat HTTP endpoint."""

    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE.

        Declared static because it takes no ``self``. Without the decorator,
        calling it on an instance would bind the instance to ``req_model``.

        Args:
            req_model: Parsed request body.
            request: The incoming HTTP request, used for disconnect detection.

        Returns:
            A ``StreamingResponse`` wrapping the module-level
            ``create_message`` generator.
        """
        # Inside this function body the bare name resolves to the module-level
        # generator, not to this method.
        event = create_message(req_model, request)
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")
# Application wiring. Registration order is kept exactly as before: the
# catch-all "/" mount comes last, presumably so it cannot shadow the API route
# (assumes routes are matched in registration order).
app = FastAPI()

# Files under ./storage/ served at /storage.
app.mount(path="/storage", app=StaticFiles(directory="./storage/"), name="static")

# Streaming chat endpoint.
app.add_api_route(
    path="/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)

# Front-end served from ./src/ at the site root (html=True).
app.mount(path="/", app=StaticFiles(directory="./src/", html=True), name="src")

# Route MetaGPT's LLM stream output through this module's dispatcher.
set_llm_stream_logfunc(llm_stream_log)
def main(host: str = "0.0.0.0", port: int = 7860):
    """Start the uvicorn server for ``app``.

    Because the entry point goes through ``fire.Fire``, both parameters can be
    overridden on the command line (``--host``, ``--port``). The defaults
    match the previously hard-coded values.

    Args:
        host: Address to bind to.
        port: TCP port to listen on.
    """
    uvicorn.run(app="__main__:app", host=host, port=port)


if __name__ == "__main__":
    fire.Fire(main)