Upload 80 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +131 -0
- Dockerfile +3 -0
- activities/activity_helpers.py +33 -0
- activities/api.py +93 -0
- activities/complete_chat.py +77 -0
- activities/eval.py +207 -0
- activities/local_demo.py +108 -0
- activities/local_test.py +87 -0
- activities/predict.py +41 -0
- activities/vllm_api_server.py +636 -0
- configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml +23 -0
- configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml +23 -0
- configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml +23 -0
- configs/agent_configs/react_agent_gpt4_async.yaml +23 -0
- configs/agent_configs/react_agent_llama_async.yaml +23 -0
- configs/agent_configs/react_agent_opt_async.yaml +23 -0
- configs/tool_configs/async_python_code_sandbox.yaml +7 -0
- configs/tool_configs/async_python_code_sandbox_docker.yaml +7 -0
- run.sh +3 -0
- run_demo.sh +5 -0
- run_local.sh +4 -0
- setup.py +40 -0
- src/infiagent/__init__.py +0 -0
- src/infiagent/agent/__init__.py +2 -0
- src/infiagent/agent/base_agent.py +337 -0
- src/infiagent/agent/react/__init__.py +4 -0
- src/infiagent/agent/react/async_react_agent.py +299 -0
- src/infiagent/conversation_sessions/__init__.py +1 -0
- src/infiagent/conversation_sessions/code_interpreter_session.py +87 -0
- src/infiagent/exceptions/__init__.py +0 -0
- src/infiagent/exceptions/exceptions.py +46 -0
- src/infiagent/llm/__init__.py +5 -0
- src/infiagent/llm/base_llm.py +36 -0
- src/infiagent/llm/client/__init__.py +0 -0
- src/infiagent/llm/client/azure_openai.py +346 -0
- src/infiagent/llm/client/llama.py +377 -0
- src/infiagent/llm/client/openai.py +306 -0
- src/infiagent/llm/client/opt.py +373 -0
- src/infiagent/prompt/__init__.py +3 -0
- src/infiagent/prompt/prompt_template.py +83 -0
- src/infiagent/prompt/simple_react_prompt.py +17 -0
- src/infiagent/prompt/zero_shot_react_prompt.py +36 -0
- src/infiagent/schemas/__init__.py +5 -0
- src/infiagent/schemas/agent_models.py +148 -0
- src/infiagent/schemas/base_models.py +0 -0
- src/infiagent/schemas/complete_models.py +236 -0
- src/infiagent/schemas/llm_models.py +91 -0
- src/infiagent/schemas/sandbox_models.py +69 -0
- src/infiagent/services/__init__.py +0 -0
- src/infiagent/services/chat_complete_service.py +196 -0
.gitignore
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
|
| 28 |
+
# PyInstaller
|
| 29 |
+
# Usually these files are written by a python script from a template
|
| 30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.nox/
|
| 42 |
+
.coverage
|
| 43 |
+
.coverage.*
|
| 44 |
+
.cache
|
| 45 |
+
nosetests.xml
|
| 46 |
+
coverage.xml
|
| 47 |
+
*.cover
|
| 48 |
+
*.py,cover
|
| 49 |
+
.hypothesis/
|
| 50 |
+
.pytest_cache/
|
| 51 |
+
|
| 52 |
+
# Translations
|
| 53 |
+
*.mo
|
| 54 |
+
*.pot
|
| 55 |
+
|
| 56 |
+
# Django stuff:
|
| 57 |
+
*.log
|
| 58 |
+
local_settings.py
|
| 59 |
+
db.sqlite3
|
| 60 |
+
db.sqlite3-journal
|
| 61 |
+
|
| 62 |
+
# Flask stuff:
|
| 63 |
+
instance/
|
| 64 |
+
.webassets-cache
|
| 65 |
+
|
| 66 |
+
# Scrapy stuff:
|
| 67 |
+
.scrapy
|
| 68 |
+
|
| 69 |
+
# Sphinx documentation
|
| 70 |
+
docs/_build/
|
| 71 |
+
build/doctrees
|
| 72 |
+
build/html
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# pyenv
|
| 81 |
+
.python-version
|
| 82 |
+
|
| 83 |
+
# celery beat schedule file
|
| 84 |
+
celerybeat-schedule
|
| 85 |
+
|
| 86 |
+
# SageMath parsed files
|
| 87 |
+
*.sage.py
|
| 88 |
+
|
| 89 |
+
# Environments
|
| 90 |
+
.env
|
| 91 |
+
.venv
|
| 92 |
+
env/
|
| 93 |
+
venv/
|
| 94 |
+
ENV/
|
| 95 |
+
env.bak/
|
| 96 |
+
venv.bak/
|
| 97 |
+
|
| 98 |
+
# Spyder project settings
|
| 99 |
+
.spyderproject
|
| 100 |
+
.spyproject
|
| 101 |
+
|
| 102 |
+
# Rope project settings
|
| 103 |
+
.ropeproject
|
| 104 |
+
|
| 105 |
+
# mkdocs documentation
|
| 106 |
+
/site
|
| 107 |
+
|
| 108 |
+
# mypy
|
| 109 |
+
.mypy_cache/
|
| 110 |
+
.dmypy.json
|
| 111 |
+
dmypy.json
|
| 112 |
+
|
| 113 |
+
# Pyre type checker
|
| 114 |
+
.pyre/
|
| 115 |
+
|
| 116 |
+
# pytype static type analyzer
|
| 117 |
+
.pytype/
|
| 118 |
+
|
| 119 |
+
# Cython debug symbols
|
| 120 |
+
cython_debug/
|
| 121 |
+
|
| 122 |
+
# JetBrains PyCharm specific
|
| 123 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, GoLand, Rider and Android Studio
|
| 124 |
+
.idea/
|
| 125 |
+
*.iml
|
| 126 |
+
|
| 127 |
+
# User-specific stuff
|
| 128 |
+
*.swp
|
| 129 |
+
*~
|
| 130 |
+
.Session.vim
|
| 131 |
+
/.sass-cache
|
Dockerfile
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3
|
| 2 |
+
|
| 3 |
+
RUN pip install pandas numpy scikit-learn matplotlib seaborn
|
activities/activity_helpers.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from sse_starlette import ServerSentEvent
|
| 4 |
+
|
| 5 |
+
from infiagent.schemas import ResponseBaseData
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
IGNORE_PING_COMMENT = {"comment": "IGNORE PING"}
|
| 9 |
+
DONE = "[DONE]"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def async_sse_response_format(response_data_gen):
|
| 13 |
+
async for content in response_data_gen:
|
| 14 |
+
if content == DONE:
|
| 15 |
+
sse_event = ServerSentEvent(data=DONE)
|
| 16 |
+
else:
|
| 17 |
+
data_dict = {
|
| 18 |
+
"response": content,
|
| 19 |
+
"ResponseBase": ResponseBaseData().dict()
|
| 20 |
+
}
|
| 21 |
+
sse_event = ServerSentEvent(data=json.dumps(data_dict, ensure_ascii=False))
|
| 22 |
+
yield sse_event
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def json_response_format(content):
|
| 26 |
+
return {
|
| 27 |
+
"response": content,
|
| 28 |
+
"ResponseBase": ResponseBaseData().dict()
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_ignore_ping_comment():
|
| 33 |
+
return lambda: ServerSentEvent(**IGNORE_PING_COMMENT)
|
activities/api.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import uuid
|
| 3 |
+
|
| 4 |
+
import uvloop
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
| 9 |
+
from starlette.responses import JSONResponse, Response
|
| 10 |
+
|
| 11 |
+
from .activity_helpers import DONE
|
| 12 |
+
from .complete_chat import complete_chat_router
|
| 13 |
+
from .predict import predict_router
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import infiagent
|
| 17 |
+
from infiagent.schemas import FailedResponseBaseData
|
| 18 |
+
from infiagent.utils import get_logger, init_logging, log_id_var
|
| 19 |
+
except ImportError:
|
| 20 |
+
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
|
| 21 |
+
from ..schemas import FailedResponseBaseData
|
| 22 |
+
from ..utils import get_logger, init_logging, log_id_var
|
| 23 |
+
|
| 24 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
| 25 |
+
|
| 26 |
+
SSE_API_PATHS = ["/complete_sse"]
|
| 27 |
+
LOG_ID_HEADER_NAME = "X-Tt-Logid"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
load_dotenv()
|
| 31 |
+
init_logging()
|
| 32 |
+
logger = get_logger()
|
| 33 |
+
|
| 34 |
+
app = FastAPI()
|
| 35 |
+
origins = ["*"]
|
| 36 |
+
app.add_middleware(
|
| 37 |
+
CORSMiddleware,
|
| 38 |
+
allow_origins=origins,
|
| 39 |
+
allow_credentials=True,
|
| 40 |
+
allow_methods=["*"],
|
| 41 |
+
allow_headers=["*"],
|
| 42 |
+
)
|
| 43 |
+
app.include_router(complete_chat_router)
|
| 44 |
+
app.include_router(predict_router)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@app.middleware("http")
|
| 48 |
+
async def log_id_middleware(request: Request, call_next):
|
| 49 |
+
# Get X-Tt-Logid from request headers
|
| 50 |
+
log_id = request.headers.get(LOG_ID_HEADER_NAME)
|
| 51 |
+
if not log_id:
|
| 52 |
+
# Generate a log_id if not present in headers
|
| 53 |
+
log_id = str(uuid.uuid4())
|
| 54 |
+
|
| 55 |
+
log_id_var.set(log_id)
|
| 56 |
+
|
| 57 |
+
response: Response = await call_next(request)
|
| 58 |
+
response.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
|
| 59 |
+
return response
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@app.exception_handler(Exception)
|
| 63 |
+
async def general_exception_handler(request, exc):
|
| 64 |
+
error_msg = "Failed to handle request. Internal Server error: {}".format(str(exc))
|
| 65 |
+
logger.error(error_msg, exc_info=True)
|
| 66 |
+
|
| 67 |
+
if request.url.path in SSE_API_PATHS:
|
| 68 |
+
return EventSourceResponse(ServerSentEvent(data=DONE))
|
| 69 |
+
else:
|
| 70 |
+
return JSONResponse(
|
| 71 |
+
status_code=500,
|
| 72 |
+
content={
|
| 73 |
+
"response": error_msg,
|
| 74 |
+
"ResponseBase": FailedResponseBaseData().dict()
|
| 75 |
+
}
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@app.exception_handler(HTTPException)
|
| 80 |
+
async def http_exception_handler(request, exc):
|
| 81 |
+
error_msg = "Failed to handle request. Error: {}".format(exc.detail)
|
| 82 |
+
logger.error(error_msg, exc_info=True)
|
| 83 |
+
|
| 84 |
+
if request.url.path in SSE_API_PATHS:
|
| 85 |
+
return EventSourceResponse(ServerSentEvent(data=DONE))
|
| 86 |
+
else:
|
| 87 |
+
return JSONResponse(
|
| 88 |
+
status_code=exc.status_code,
|
| 89 |
+
content={
|
| 90 |
+
"response": error_msg,
|
| 91 |
+
"ResponseBase": FailedResponseBaseData().dict()
|
| 92 |
+
}
|
| 93 |
+
)
|
activities/complete_chat.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Request, HTTPException
|
| 2 |
+
from pydantic import ValidationError
|
| 3 |
+
from sse_starlette import EventSourceResponse, ServerSentEvent
|
| 4 |
+
|
| 5 |
+
from .activity_helpers import async_sse_response_format, IGNORE_PING_COMMENT, json_response_format
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
import infiagent
|
| 9 |
+
from infiagent.db.conversation_dao import ConversationDAO
|
| 10 |
+
from infiagent.schemas import ChatCompleteRequest
|
| 11 |
+
from infiagent.services.chat_complete_sse_service import chat_event_generator, chat_event_response
|
| 12 |
+
from infiagent.tools.code_sandbox.async_sandbox_client import AsyncSandboxClient
|
| 13 |
+
from infiagent.utils import get_logger
|
| 14 |
+
except ImportError:
|
| 15 |
+
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
|
| 16 |
+
from ..db.conversation_dao import ConversationDAO
|
| 17 |
+
from ..schemas import ChatCompleteRequest
|
| 18 |
+
from ..services.chat_complete_sse_service import chat_event_generator, chat_event_response
|
| 19 |
+
from ..tools.code_sandbox.async_sandbox_client import AsyncSandboxClient
|
| 20 |
+
from ..utils import get_logger
|
| 21 |
+
|
| 22 |
+
complete_chat_router = APIRouter()
|
| 23 |
+
logger = get_logger()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@complete_chat_router.post("/complete_sse")
|
| 27 |
+
async def complete_sse(request: Request):
|
| 28 |
+
body_str = await request.body()
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
chat_request = ChatCompleteRequest.parse_raw(body_str)
|
| 32 |
+
logger.info("Got chat request: {}".format(chat_request))
|
| 33 |
+
except ValidationError as e:
|
| 34 |
+
error_msg = "Invalid input chat_request. Error: {}".format(str(e))
|
| 35 |
+
raise HTTPException(status_code=400, detail=error_msg)
|
| 36 |
+
|
| 37 |
+
return EventSourceResponse(async_sse_response_format(chat_event_generator(chat_request)),
|
| 38 |
+
ping_message_factory=lambda: ServerSentEvent(**IGNORE_PING_COMMENT))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@complete_chat_router.post("/complete")
|
| 42 |
+
async def complete(request: Request):
|
| 43 |
+
body_str = await request.body()
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
chat_request = ChatCompleteRequest.parse_raw(body_str)
|
| 47 |
+
logger.info("Got chat request: {}".format(chat_request))
|
| 48 |
+
except ValidationError as e:
|
| 49 |
+
error_msg = "Invalid input chat_request. Error: {}".format(str(e))
|
| 50 |
+
raise HTTPException(status_code=400, detail=error_msg)
|
| 51 |
+
|
| 52 |
+
response_items = await chat_event_response(chat_request)
|
| 53 |
+
|
| 54 |
+
return json_response_format(response_items)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@complete_chat_router.get("/heartbeat")
|
| 58 |
+
async def heartbeat(chat_id: str = None, session_id: str = None):
|
| 59 |
+
if not chat_id and not session_id:
|
| 60 |
+
raise HTTPException(status_code=400, detail="Either chat_id or session_id must be provided.")
|
| 61 |
+
|
| 62 |
+
input_chat_id = chat_id or session_id
|
| 63 |
+
|
| 64 |
+
conversation = await ConversationDAO.get_conversation(input_chat_id)
|
| 65 |
+
if not conversation:
|
| 66 |
+
logger.info(f'Call heartbeat on a non-exist conversion, {input_chat_id}')
|
| 67 |
+
return json_response_format("conversation is not created, skip")
|
| 68 |
+
|
| 69 |
+
if conversation.sandbox_id is None:
|
| 70 |
+
logger.error(f'No sandbox id for heartbeat, chat id {input_chat_id}')
|
| 71 |
+
raise HTTPException(status_code=404, detail=f'No sandbox id for heartbeat, chat id {input_chat_id}')
|
| 72 |
+
|
| 73 |
+
# TODO Add exception handling logic here for heartbeat failed in sandbox side
|
| 74 |
+
heartbeat_response = await AsyncSandboxClient(conversation.sandbox_id).heartbeat()
|
| 75 |
+
logger.info(f"Heartbeat response {heartbeat_response}")
|
| 76 |
+
|
| 77 |
+
return json_response_format("succeed")
|
activities/eval.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import argparse
|
| 4 |
+
import asyncio
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import io
|
| 9 |
+
|
| 10 |
+
import openai
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import infiagent
|
| 14 |
+
from infiagent.utils import get_logger, upload_files, get_file_name_and_path
|
| 15 |
+
from infiagent.services.chat_complete_service import predict
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class UploadedFile(io.BytesIO):
|
| 22 |
+
def __init__(self, path):
|
| 23 |
+
with open(path, 'rb') as file:
|
| 24 |
+
data = file.read()
|
| 25 |
+
|
| 26 |
+
super().__init__(data)
|
| 27 |
+
|
| 28 |
+
self.name = path.split("/")[-1] # 获取文件名
|
| 29 |
+
self.type = 'application/octet-stream' # 或者其他适当的 MIME 类型
|
| 30 |
+
self.size = len(data)
|
| 31 |
+
|
| 32 |
+
def __repr__(self):
|
| 33 |
+
return f"MyUploadedFile(name={self.name}, size={self.size}, type={self.type})"
|
| 34 |
+
|
| 35 |
+
def __len__(self):
|
| 36 |
+
|
| 37 |
+
return self.size
|
| 38 |
+
|
| 39 |
+
# # 使用例子
|
| 40 |
+
# file_path = "path/to/your/file"
|
| 41 |
+
# uploaded_file = MyUploadedFile(file_path)
|
| 42 |
+
|
| 43 |
+
# print(uploaded_file)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _get_script_params():
|
| 47 |
+
try:
|
| 48 |
+
parser = argparse.ArgumentParser()
|
| 49 |
+
parser.add_argument('--llm',
|
| 50 |
+
help='LLM Model for demo',
|
| 51 |
+
required=False, type=str)
|
| 52 |
+
parser.add_argument('--api_key',
|
| 53 |
+
help='Open API token key.',
|
| 54 |
+
required=False, type=str)
|
| 55 |
+
|
| 56 |
+
parser.add_argument('--config_path',
|
| 57 |
+
help='Config path for demo',
|
| 58 |
+
default="configs/agent_configs/react_agent_llama_async.yaml",
|
| 59 |
+
required=False, type=str)
|
| 60 |
+
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
|
| 63 |
+
return args
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error("Failed to get script input arguments: {}".format(str(e)), exc_info=True)
|
| 66 |
+
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def extract_questions_and_concepts(file_path):
|
| 71 |
+
# Read the content of the text file
|
| 72 |
+
with open(file_path, 'r') as file:
|
| 73 |
+
content = file.read()
|
| 74 |
+
|
| 75 |
+
# Use regular expressions to extract questions and concepts
|
| 76 |
+
pattern = r'\\Question{(.*?)}\s*\\Concepts{(.*?)}'
|
| 77 |
+
matches = re.findall(pattern, content, re.DOTALL)
|
| 78 |
+
|
| 79 |
+
# Build a list of dictionaries containing the questions and concepts
|
| 80 |
+
data = []
|
| 81 |
+
for match in matches:
|
| 82 |
+
question = match[0].strip()
|
| 83 |
+
concepts = [concept.strip() for concept in match[1].split(',')]
|
| 84 |
+
data.append({
|
| 85 |
+
'question': question,
|
| 86 |
+
'concepts': concepts
|
| 87 |
+
})
|
| 88 |
+
|
| 89 |
+
return data
|
| 90 |
+
|
| 91 |
+
def read_dicts_from_file(file_name):
|
| 92 |
+
"""
|
| 93 |
+
Read a file with each line containing a JSON string representing a dictionary,
|
| 94 |
+
and return a list of dictionaries.
|
| 95 |
+
|
| 96 |
+
:param file_name: Name of the file to read from.
|
| 97 |
+
:return: List of dictionaries.
|
| 98 |
+
"""
|
| 99 |
+
dict_list = []
|
| 100 |
+
with open(file_name, 'r') as file:
|
| 101 |
+
for line in file:
|
| 102 |
+
# Convert the JSON string back to a dictionary.
|
| 103 |
+
dictionary = json.loads(line.rstrip('\n'))
|
| 104 |
+
dict_list.append(dictionary)
|
| 105 |
+
return dict_list
|
| 106 |
+
|
| 107 |
+
def read_questions(file_path):
|
| 108 |
+
print(file_path)
|
| 109 |
+
with open(file_path) as f:
|
| 110 |
+
questions = json.load(f)
|
| 111 |
+
|
| 112 |
+
return questions
|
| 113 |
+
|
| 114 |
+
def extract_data_from_folder(folder_path):
|
| 115 |
+
|
| 116 |
+
print(f'folder_path {folder_path}')
|
| 117 |
+
extracted_data = {}
|
| 118 |
+
# Traverse the files in the folder
|
| 119 |
+
for file_name in os.listdir(folder_path):
|
| 120 |
+
if file_name.endswith('.questions'): # You can filter files based on their type
|
| 121 |
+
file_path = os.path.join(folder_path, file_name)
|
| 122 |
+
file_data = read_questions(file_path)
|
| 123 |
+
file_name_without_extension = os.path.splitext(file_name)[0]
|
| 124 |
+
extracted_data[file_name_without_extension] = file_data
|
| 125 |
+
|
| 126 |
+
return extracted_data
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
async def main():
|
| 130 |
+
extracted_data = read_dicts_from_file('./data/da-dev-questions.jsonl')
|
| 131 |
+
args = _get_script_params()
|
| 132 |
+
|
| 133 |
+
model_name = getattr(args, "llm", None)
|
| 134 |
+
open_ai_key = getattr(args, "api_key", None)
|
| 135 |
+
|
| 136 |
+
if "OPEN_AI" in model_name:
|
| 137 |
+
logger.info("setup open ai ")
|
| 138 |
+
if os.environ.get("OPENAI_API_KEY") is None:
|
| 139 |
+
if open_ai_key:
|
| 140 |
+
openai.api_key = open_ai_key
|
| 141 |
+
os.environ["OPENAI_API_KEY"] = open_ai_key
|
| 142 |
+
else:
|
| 143 |
+
raise ValueError("OPENAI_API_KEY is None, please provide open ai key to use open ai model. Adding "
|
| 144 |
+
"'--api_key' to set it up")
|
| 145 |
+
|
| 146 |
+
# 获取 'openai' 的 logger
|
| 147 |
+
openai_logger = logging.getLogger('openai')
|
| 148 |
+
# 设置日志级别为 'WARNING',这样 'INFO' 级别的日志就不会被打印了
|
| 149 |
+
openai_logger.setLevel(logging.WARNING)
|
| 150 |
+
else:
|
| 151 |
+
logger.info("use local model ")
|
| 152 |
+
|
| 153 |
+
table_path = 'data/da-dev-tables'
|
| 154 |
+
results = []
|
| 155 |
+
|
| 156 |
+
i = 1
|
| 157 |
+
for q in extracted_data:
|
| 158 |
+
input_text = q['question']
|
| 159 |
+
concepts = q['concepts']
|
| 160 |
+
file_path = q['file_name']
|
| 161 |
+
constraints = q['constraints']
|
| 162 |
+
format = q['format']
|
| 163 |
+
|
| 164 |
+
file_path = os.path.join(table_path, file_path)
|
| 165 |
+
|
| 166 |
+
print(f'input_text: {input_text}')
|
| 167 |
+
print(f'concepts: {concepts}')
|
| 168 |
+
print(f'file_path: {file_path}')
|
| 169 |
+
|
| 170 |
+
uploaded_file = UploadedFile(file_path)
|
| 171 |
+
print(uploaded_file)
|
| 172 |
+
|
| 173 |
+
prompt = f"Question: {input_text}\n{constraints}\n"
|
| 174 |
+
|
| 175 |
+
response = await predict(
|
| 176 |
+
prompt=prompt,
|
| 177 |
+
model_name=model_name,
|
| 178 |
+
config_path=args.config_path,
|
| 179 |
+
uploaded_files=[uploaded_file]
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
iteration_result = {
|
| 183 |
+
'id': q['id'],
|
| 184 |
+
'input_text': prompt,
|
| 185 |
+
'concepts': concepts,
|
| 186 |
+
'file_path': file_path,
|
| 187 |
+
'response': response,
|
| 188 |
+
'format': format
|
| 189 |
+
}
|
| 190 |
+
results.append(iteration_result)
|
| 191 |
+
print(f"response: {response}")
|
| 192 |
+
|
| 193 |
+
if i % 10 == 0:
|
| 194 |
+
with open('results_{}.json'.format(model_name), 'w') as outfile:
|
| 195 |
+
json.dump(results, outfile, indent=4)
|
| 196 |
+
|
| 197 |
+
i += 1
|
| 198 |
+
|
| 199 |
+
with open('results_{}.json'.format(model_name), 'w') as outfile:
|
| 200 |
+
json.dump(results, outfile, indent=4)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == '__main__':
|
| 204 |
+
asyncio.run(main())
|
| 205 |
+
# main()
|
| 206 |
+
|
| 207 |
+
|
activities/local_demo.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import asyncio
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
import streamlit as st # type: ignore
|
| 8 |
+
import uvloop
|
| 9 |
+
import openai
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import infiagent
|
| 13 |
+
from infiagent.utils import get_logger, upload_files
|
| 14 |
+
from infiagent.services.chat_complete_service import predict
|
| 15 |
+
except ImportError:
|
| 16 |
+
raise (
|
| 17 |
+
"import infiagent failed, please install infiagent by 'pip install -e .' in the pipeline directory of ADA-Agent")
|
| 18 |
+
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_script_params():
|
| 25 |
+
try:
|
| 26 |
+
parser = argparse.ArgumentParser()
|
| 27 |
+
parser.add_argument('--llm',
|
| 28 |
+
help='LLM Model for demo',
|
| 29 |
+
required=False, type=str)
|
| 30 |
+
parser.add_argument('--api_key',
|
| 31 |
+
help='Open API token key.',
|
| 32 |
+
required=False, type=str)
|
| 33 |
+
parser.add_argument('--config_path',
|
| 34 |
+
help='Config path for demo',
|
| 35 |
+
# default="configs/agent_configs/react_agent_gpt4_async.yaml",
|
| 36 |
+
required=False, type=str)
|
| 37 |
+
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
return args
|
| 41 |
+
except Exception as e:
|
| 42 |
+
logger.error("Failed to get script input arguments: {}".format(str(e)), exc_info=True)
|
| 43 |
+
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
async def main():
|
| 48 |
+
args = _get_script_params()
|
| 49 |
+
|
| 50 |
+
model_name = getattr(args, "llm", None)
|
| 51 |
+
open_ai_key = getattr(args, "api_key", None)
|
| 52 |
+
config_path = getattr(args, "config_path", None)
|
| 53 |
+
|
| 54 |
+
if "OPEN_AI" in model_name:
|
| 55 |
+
logger.info("setup open ai ")
|
| 56 |
+
if os.environ.get("OPENAI_API_KEY") is None:
|
| 57 |
+
if open_ai_key:
|
| 58 |
+
openai.api_key = open_ai_key
|
| 59 |
+
os.environ["OPENAI_API_KEY"] = open_ai_key
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
"OPENAI_API_KEY is None, please provide opekn ai key to use open ai model. Adding '--api_key' to set it up")
|
| 63 |
+
|
| 64 |
+
# 获取 'openai' 的 logger
|
| 65 |
+
openai_logger = logging.getLogger('openai')
|
| 66 |
+
# 设置日志级别为 'WARNING',这样 'INFO' 级别的日志就不会被打印了
|
| 67 |
+
openai_logger.setLevel(logging.WARNING)
|
| 68 |
+
|
| 69 |
+
else:
|
| 70 |
+
logger.info("use local model ")
|
| 71 |
+
|
| 72 |
+
st.set_page_config(layout="centered")
|
| 73 |
+
|
| 74 |
+
st.title("InfiAgent Code Interpreter Demo 🚀")
|
| 75 |
+
|
| 76 |
+
# Initialize session state variables if not already present
|
| 77 |
+
if 'chat_history' not in st.session_state:
|
| 78 |
+
st.session_state.chat_history = []
|
| 79 |
+
|
| 80 |
+
# UI components
|
| 81 |
+
input_text = st.text_area("Write your prompt")
|
| 82 |
+
uploaded_files = st.file_uploader("Upload your files", accept_multiple_files=True)
|
| 83 |
+
button_pressed = st.button("Run code interpreter", use_container_width=True)
|
| 84 |
+
|
| 85 |
+
# When button is pressed
|
| 86 |
+
if button_pressed and input_text != "":
|
| 87 |
+
# Add user message to chat history
|
| 88 |
+
st.session_state.chat_history.append({"role": "user", "message": input_text})
|
| 89 |
+
|
| 90 |
+
# Predict response (assuming you have the necessary async handling)
|
| 91 |
+
response = await predict(
|
| 92 |
+
prompt=input_text,
|
| 93 |
+
model_name=model_name,
|
| 94 |
+
config_path=config_path,
|
| 95 |
+
uploaded_files=uploaded_files,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Add assistant message to chat history
|
| 99 |
+
st.session_state.chat_history.append({"role": "assistant", "message": response})
|
| 100 |
+
|
| 101 |
+
# Display chat history
|
| 102 |
+
for chat in st.session_state.chat_history:
|
| 103 |
+
with st.chat_message(chat["role"]):
|
| 104 |
+
st.write(chat["message"])
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
+
asyncio.run(main())
|
activities/local_test.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 4 |
+
from pydantic import ValidationError
|
| 5 |
+
from sse_starlette import EventSourceResponse
|
| 6 |
+
|
| 7 |
+
from .activity_helpers import (
|
| 8 |
+
async_sse_response_format,
|
| 9 |
+
get_ignore_ping_comment,
|
| 10 |
+
json_response_format,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import infiagent
|
| 16 |
+
from infiagent.schemas import ChatCompleteRequest
|
| 17 |
+
from infiagent.services.complete_local_test import (
|
| 18 |
+
chat_local_event,
|
| 19 |
+
chat_local_event_generator,
|
| 20 |
+
)
|
| 21 |
+
from infiagent.utils import get_logger
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
|
| 24 |
+
from ..schemas import ChatCompleteRequest
|
| 25 |
+
from ..services.complete_local_test import (
|
| 26 |
+
chat_local_event,
|
| 27 |
+
chat_local_event_generator,
|
| 28 |
+
)
|
| 29 |
+
from ..utils import get_logger
|
| 30 |
+
|
| 31 |
+
logger = get_logger()
|
| 32 |
+
local_app = FastAPI()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@local_app.post("/local_sse_test")
|
| 36 |
+
async def complete_sse(request: Request):
|
| 37 |
+
body_str = await request.body()
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
chat_request = ChatCompleteRequest.parse_raw(body_str)
|
| 41 |
+
logger.info("Got chat request: {}".format(chat_request))
|
| 42 |
+
except ValidationError as e:
|
| 43 |
+
error_msg = "Invalid input chat_request. Error: {}".format(str(e))
|
| 44 |
+
raise HTTPException(status_code=500, detail=error_msg)
|
| 45 |
+
|
| 46 |
+
return EventSourceResponse(async_sse_response_format(chat_local_event_generator(chat_request)),
|
| 47 |
+
ping_message_factory=get_ignore_ping_comment())
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@local_app.post("/local_json_test")
async def complete_json(request: Request):
    """Local JSON test endpoint.

    Parses the raw body into a ChatCompleteRequest, runs the chat to
    completion, and returns all resulting items as one JSON response.

    Raises:
        HTTPException: 500 with the validation detail when the body does not
            parse as a ChatCompleteRequest.
    """
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        raise HTTPException(
            status_code=500,
            detail="Invalid input chat_request. Error: {}".format(str(e)),
        )
    logger.info("Got chat request: {}".format(chat_request))

    response_items = await chat_local_event(chat_request)
    return json_response_format(response_items)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@local_app.post("/exception_test")
async def complete_exception_sse(request: Request):
    """SSE endpoint used to exercise exception handling (no ping suppression).

    Renamed from ``complete_json``: the original definition reused the name of
    the ``/local_json_test`` handler, silently rebinding it at module level.
    The HTTP route is unchanged, so clients are unaffected.

    Raises:
        HTTPException: 500 with the validation detail when the body does not
            parse as a ChatCompleteRequest.
    """
    body_str = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
        logger.info("Got chat request: {}".format(chat_request))
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        raise HTTPException(status_code=500, detail=error_msg)
    return EventSourceResponse(
        async_sse_response_format(chat_local_event_generator(chat_request)))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
async def exception_test(request: Request):
    """Async generator used to simulate a failure mid-stream.

    Reads a JSON body; if it contains a truthy "exception" key, raises to let
    callers verify error propagation, otherwise yields a single success event.

    Fixes: the original did ``yield iter(["Success"])``, emitting an iterator
    object instead of the event string, and misspelled "triggered" in the
    error message.
    """
    body_str = await request.body()
    json_val = json.loads(body_str)
    exception_type = json_val.get("exception", None)

    if exception_type:
        raise ValueError("Error triggered!")
    else:
        # Yield the event payload itself, not an iterator over it.
        yield "Success"
|
activities/predict.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, File, Form, UploadFile
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
import infiagent
|
| 6 |
+
from infiagent.services.chat_complete_service import predict
|
| 7 |
+
except ImportError:
|
| 8 |
+
print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
|
| 9 |
+
from ..services.chat_complete_service import predict
|
| 10 |
+
|
| 11 |
+
predict_router = APIRouter()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@predict_router.post("/predict")
async def chat_predict(
        prompt: str = Form(...),
        model_name: str = Form(...),
        psm: Optional[str] = Form(None),
        dc: Optional[str] = Form(None),
        temperature: Optional[str] = Form(None),
        top_p: Optional[str] = Form(None),
        top_k: Optional[str] = Form(None),
        files: List[UploadFile] = File(...)
):
    """Run a one-shot prediction over the uploaded files.

    Form fields arrive as strings; optional sampling parameters are only
    forwarded when provided. Returns ``{"answer": <model response>}``.
    """
    kwargs = {}
    if psm:
        kwargs['psm'] = psm
    if dc:
        kwargs['dc'] = dc
    if temperature:
        kwargs['temperature'] = float(temperature)
    if top_p:
        kwargs['top_p'] = float(top_p)
    if top_k:
        # top_k counts candidate tokens and must be an int; the original
        # passed a float, which integer-typed sampling backends reject.
        # float() first so inputs like "40.0" keep working.
        kwargs['top_k'] = int(float(top_k))

    response = await predict(prompt, model_name, files, **kwargs)

    return {
        "answer": response
    }
|
activities/vllm_api_server.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from
|
| 2 |
+
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
from http import HTTPStatus
|
| 9 |
+
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import fastapi
|
| 12 |
+
import uvicorn
|
| 13 |
+
from fastapi import Request
|
| 14 |
+
from fastapi.exceptions import RequestValidationError
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
+
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
| 17 |
+
from packaging import version
|
| 18 |
+
|
| 19 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
| 20 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
| 21 |
+
from vllm.entrypoints.openai.protocol import (
|
| 22 |
+
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
| 23 |
+
CompletionResponseStreamChoice, CompletionStreamResponse,
|
| 24 |
+
ChatCompletionRequest, ChatCompletionResponse,
|
| 25 |
+
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
| 26 |
+
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
| 27 |
+
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
| 28 |
+
from vllm.logger import init_logger
|
| 29 |
+
from vllm.outputs import RequestOutput
|
| 30 |
+
from vllm.sampling_params import SamplingParams
|
| 31 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
| 32 |
+
from vllm.utils import random_uuid
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import fastchat
|
| 36 |
+
from fastchat.conversation import Conversation, SeparatorStyle
|
| 37 |
+
from fastchat.model.model_adapter import get_conversation_template
|
| 38 |
+
_fastchat_available = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
_fastchat_available = False
|
| 41 |
+
|
| 42 |
+
TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
# Populated in the __main__ block from CLI args before the server starts.
served_model = None
app = fastapi.FastAPI()
# AsyncLLMEngine instance; also created in the __main__ block.
engine = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    """Wrap *message* in an OpenAI-style error body with the given status."""
    error = ErrorResponse(message=message, type="invalid_request_error")
    return JSONResponse(error.dict(), status_code=status_code.value)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    """Translate FastAPI validation failures into OpenAI-style 400 errors."""
    detail = str(exc)
    return create_error_response(HTTPStatus.BAD_REQUEST, detail)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
async def check_model(request) -> Optional[JSONResponse]:
    """Return None if *request.model* is the served model, else a 404 response."""
    if request.model != served_model:
        return create_error_response(
            HTTPStatus.NOT_FOUND,
            f"The model `{request.model}` does not exist.",
        )
    return None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def get_gen_prompt(request) -> str:
    """Render a chat request into a single prompt string.

    Uses fastchat's conversation template for ``request.model`` to join the
    system/user/assistant messages with the model-specific separators. If
    ``request.messages`` is already a plain string it is used verbatim.

    Raises:
        ModuleNotFoundError: if fastchat is not installed.
        ImportError: if the installed fastchat is older than 0.2.23.
        ValueError: if a message carries an unknown role.
    """
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "fastchat is not installed. Please install fastchat to use "
            "the chat completion and conversation APIs: `$ pip install fschat`"
        )
    # The Conversation fields accessed below require fastchat >= 0.2.23.
    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
        raise ImportError(
            f"fastchat version is low. Current version: {fastchat.__version__} "
            "Please upgrade fastchat to use: `$ pip install -U fschat`")

    conv = get_conversation_template(request.model)
    # Rebuild a fresh Conversation so appending messages below does not
    # mutate the shared template returned by get_conversation_template.
    conv = Conversation(
        name=conv.name,
        system_template=conv.system_template,
        system_message=conv.system_message,
        roles=conv.roles,
        messages=list(conv.messages),  # prevent in-place modification
        offset=conv.offset,
        sep_style=SeparatorStyle(conv.sep_style),
        sep=conv.sep,
        sep2=conv.sep2,
        stop_str=conv.stop_str,
        stop_token_ids=conv.stop_token_ids,
    )

    if isinstance(request.messages, str):
        prompt = request.messages
    else:
        for message in request.messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.system_message = message["content"]
            elif msg_role == "user":
                conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

    return prompt
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
async def check_length(
    request: Union[ChatCompletionRequest, CompletionRequest],
    prompt: Optional[str] = None,
    prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
    """Tokenize the prompt (if needed) and validate the context-length budget.

    Exactly one of *prompt* / *prompt_ids* must be provided.

    Side effect: when ``request.max_tokens`` is None it is filled in with the
    remaining budget ``max_model_len - token_num``.

    Returns:
        ``(input_ids, error)`` where ``error`` is None on success, otherwise
        a 400 JSON response describing the context-length overflow.

    Raises:
        ValueError: if zero or both of *prompt* / *prompt_ids* are given.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this validation.
    if (prompt is None) == (prompt_ids is None):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    if prompt_ids is not None:
        input_ids = prompt_ids
    else:
        input_ids = tokenizer(prompt).input_ids
    token_num = len(input_ids)

    if request.max_tokens is None:
        request.max_tokens = max_model_len - token_num
    if token_num + request.max_tokens > max_model_len:
        return input_ids, create_error_response(
            HTTPStatus.BAD_REQUEST,
            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {request.max_tokens + token_num} tokens "
            f"({token_num} in the messages, "
            f"{request.max_tokens} in the completion). "
            f"Please reduce the length of the messages or completion.",
        )
    return input_ids, None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers 200 with an empty body."""
    return Response(status_code=HTTPStatus.OK.value)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    card = ModelCard(id=served_model,
                     root=served_model,
                     permission=[ModelPermission()])
    return ModelList(data=[card])
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs.

    For each generated token: records the token string, its logprob, the
    running character offset of the token within the text, and the full
    candidate-token logprob map for that position.
    """
    result = LogProbs()
    next_offset = initial_text_offset
    for token_id, candidate_logprobs in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        result.tokens.append(token)
        result.token_logprobs.append(candidate_logprobs[token_id])
        result.text_offset.append(next_offset)
        next_offset += len(token)
        result.top_logprobs.append({
            tokenizer.convert_ids_to_tokens(tid): lp
            for tid, lp in candidate_logprobs.items()
        })
    return result
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/chat/create
    for the API specification. This API mimics the OpenAI ChatCompletion API.

    NOTE: Currently we do not support the following features:
        - function_call (Users should implement this by themselves)
        - logit_bias (to be supported by vLLM engine)
    """
    logger.info(f"Received chat completion request: {request}")

    # Reject requests that name a model other than the one being served.
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.logit_bias is not None and len(request.logit_bias) > 0:
        # TODO: support logit_bias in vLLM engine.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    prompt = await get_gen_prompt(request)
    token_ids, error_check_ret = await check_length(request, prompt=prompt)
    if error_check_ret is not None:
        return error_check_ret

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    # The OpenAI `created` field is a Unix timestamp in seconds.
    # time.monotonic() (used previously) counts from an arbitrary origin
    # and produced meaningless `created` values; time.time() is correct.
    created_time = int(time.time())
    try:
        # spaces_between_special_tokens = request.spaces_between_special_tokens
        sampling_params = SamplingParams(
            n=request.n,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stop_token_ids=request.stop_token_ids,
            max_tokens=request.max_tokens,
            best_of=request.best_of,
            top_k=request.top_k,
            ignore_eos=request.ignore_eos,
            use_beam_search=request.use_beam_search,
            skip_special_tokens=request.skip_special_tokens,
            # spaces_between_special_tokens=spaces_between_special_tokens,
        )
    except ValueError as e:
        # SamplingParams validates its arguments and raises on bad values.
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = engine.generate(prompt, sampling_params, request_id,
                                       token_ids)

    def create_stream_response_json(
        index: int,
        text: str,
        finish_reason: Optional[str] = None,
    ) -> str:
        # Serialize one SSE delta chunk in the OpenAI streaming format.
        choice_data = ChatCompletionResponseStreamChoice(
            index=index,
            delta=DeltaMessage(content=text),
            finish_reason=finish_reason,
        )
        response = ChatCompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        # First chunk with role
        for i in range(request.n):
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(id=request_id,
                                                 choices=[choice_data],
                                                 model=model_name)
            data = chunk.json(exclude_unset=True, ensure_ascii=False)
            yield f"data: {data}\n\n"

        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                # The engine yields cumulative text; emit only the new suffix.
                delta_text = output.text[len(previous_texts[i]):]
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if request.stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "Client disconnected")
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        choice_data = ChatCompletionResponseChoice(
            index=output.index,
            message=ChatMessage(role="assistant", content=output.text),
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification. This API mimics the OpenAI Completion API.

    NOTE: Currently we do not support the following features:
        - echo (since the vLLM engine does not currently support
          getting the logprobs of prompt tokens)
        - suffix (the language models we currently support do not support
          suffix)
        - logit_bias (to be supported by vLLM engine)
    """
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the vLLM engine does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None and len(request.logit_bias) > 0:
        # TODO: support logit_bias in vLLM engine.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"

    # Decide whether request.prompt is raw text or pre-tokenized ids.
    use_token_ids = False
    if isinstance(request.prompt, list):
        if len(request.prompt) == 0:
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "please provide at least one prompt")
        first_element = request.prompt[0]
        if isinstance(first_element, int):
            use_token_ids = True
            prompt = request.prompt
        elif isinstance(first_element, (str, list)):
            # TODO: handles multiple prompt case in list[list[int]]
            if len(request.prompt) > 1:
                return create_error_response(
                    HTTPStatus.BAD_REQUEST,
                    "multiple prompts in a batch is not currently supported")
            use_token_ids = not isinstance(first_element, str)
            prompt = request.prompt[0]
        # NOTE(review): a list whose first element is neither int/str/list
        # leaves `prompt` unbound and would raise UnboundLocalError below —
        # presumably unreachable given request validation; confirm upstream.
    else:
        prompt = request.prompt

    if use_token_ids:
        _, error_check_ret = await check_length(request, prompt_ids=prompt)
    else:
        token_ids, error_check_ret = await check_length(request, prompt=prompt)
    if error_check_ret is not None:
        return error_check_ret

    # The OpenAI `created` field is a Unix timestamp in seconds.
    # time.monotonic() (used previously) counts from an arbitrary origin
    # and produced meaningless `created` values; time.time() is correct.
    created_time = int(time.time())
    try:
        # spaces_between_special_tokens = request.spaces_between_special_tokens
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            stop_token_ids=request.stop_token_ids,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
            skip_special_tokens=request.skip_special_tokens,
            # spaces_between_special_tokens=spaces_between_special_tokens,
        )
    except ValueError as e:
        # SamplingParams validates its arguments and raises on bad values.
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    if use_token_ids:
        result_generator = engine.generate(None,
                                           sampling_params,
                                           request_id,
                                           prompt_token_ids=prompt)
    else:
        result_generator = engine.generate(prompt, sampling_params, request_id,
                                           token_ids)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream
              and (request.best_of is None or request.n == request.best_of)
              and not request.use_beam_search)

    def create_stream_response_json(
        index: int,
        text: str,
        logprobs: Optional[LogProbs] = None,
        finish_reason: Optional[str] = None,
    ) -> str:
        # Serialize one SSE delta chunk in the OpenAI streaming format.
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                # The engine yields cumulative text; emit only the new suffix.
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = (LogProbs()
                                if request.logprobs is not None else None)
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "Client disconnected")
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
if __name__ == "__main__":
    # CLI entry point: parse server + engine args, wire CORS, build the
    # async vLLM engine and tokenizer, then serve the FastAPI app.
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials", action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins", type=json.loads, default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods", type=json.loads, default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers", type=json.loads, default=["*"],
                        help="allowed headers")
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not "
                             "specified, the model name will be the same as "
                             "the huggingface name.")

    # Let vLLM append its own engine flags (model, tokenizer, ...) to the parser.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    # Model name advertised by the API defaults to the huggingface model id.
    served_model = args.served_model_name if args.served_model_name is not None else args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.max_model_len

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=engine_args.tokenizer_mode,
                              trust_remote_code=engine_args.trust_remote_code)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# Azure OpenAI GPT-3.5-turbo ReAct agent with a local python sandbox tool.
name: react_template
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: gpt-35-turbo
  # Fixed garbled module path ("in f i a gen r.llm") — all sibling configs use infiagent.llm.
  module_name: infiagent.llm
  class_name: AzureOpenAIGPTClient
  params:
    temperature: 0.2
    top_p: 0.95
    repetition_penalty: 1.0
    max_tokens: 4096
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox.yaml
|
configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# Azure OpenAI GPT-4 (0613) ReAct agent with a local python sandbox tool.
name: gpt_4_react
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: gpt-4-0613
  module_name: infiagent.llm
  class_name: AzureOpenAIGPTClient
  params:
    temperature: 0.2
    top_p: 0.95
    repetition_penalty: 1.0
    max_tokens: 4096
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox.yaml
|
configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# Azure OpenAI GPT-4 ReAct agent; identical to the non-docker config except
# the sandbox tool config points at the docker-based sandbox.
# NOTE(review): the file name ("dcoker") looks like a typo for "docker".
name: gpt_4_react
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: gpt-4-0613
  module_name: infiagent.llm
  class_name: AzureOpenAIGPTClient
  params:
    temperature: 0.2
    top_p: 0.95
    repetition_penalty: 1.0
    max_tokens: 4096
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox_docker.yaml
|
configs/agent_configs/react_agent_gpt4_async.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# OpenAI (non-Azure) GPT-4 ReAct agent with a local python sandbox tool.
name: react_template
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: gpt-4
  module_name: infiagent.llm
  class_name: OpenAIGPTClient
  params:
    temperature: 0.0
    top_p: 0.9
    repetition_penalty: 1.0
    max_tokens: 1024
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox.yaml
|
configs/agent_configs/react_agent_llama_async.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# Llama-2-7b ReAct agent (served through an OpenAI-compatible endpoint,
# see LlamaOpenAIClient) with a local python sandbox tool.
name: react_template
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: meta-llama/Llama-2-7b-hf
  module_name: infiagent.llm
  class_name: LlamaOpenAIClient
  params:
    temperature: 0.0
    top_p: 0.9
    repetition_penalty: 1.0
    max_tokens: 1024
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox.yaml
|
configs/agent_configs/react_agent_opt_async.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReAct Agent Template
# facebook/opt-125m ReAct agent (OpenAI-compatible endpoint, see
# OptOpenAIClient) with a local python sandbox tool.
name: react_template
version: 0.0.1
type: react
description: A react agent capable of code interpreter
module_name: infiagent.agent.react
class_name: AsyncReactAgent
target_tasks:
  - code interpreter
llm:
  model_name: facebook/opt-125m
  module_name: infiagent.llm
  class_name: OptOpenAIClient
  params:
    temperature: 0.0
    top_p: 0.9
    repetition_penalty: 1.0
    max_tokens: 1024
prompt_template: !prompt ZeroShotReactPrompt
plugins:
  - name: python_code_sandbox
    type: tool
    config: configs/tool_configs/async_python_code_sandbox.yaml
|
configs/tool_configs/async_python_code_sandbox.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Local (in-process) python sandbox tool used by the ReAct agents.
name: python_code_sandbox
version: 0.0.1
type: tool
description: this tool can help to run python script with python code as input
module_name: infiagent.tools
class_name: AsyncPythonSandBoxTool
# Placeholder; a real session id is assigned at runtime by the agent session.
session_id: none
|
configs/tool_configs/async_python_code_sandbox_docker.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker-backed python sandbox tool variant (class CodeTool instead of the
# in-process AsyncPythonSandBoxTool).
name: python_code_sandbox
version: 0.0.1
type: tool
description: this tool can help to run python script with python code as input
module_name: infiagent.tools
class_name: CodeTool
# Placeholder; a real session id is assigned at runtime by the agent session.
session_id: none
|
run.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch the production HTTP API (activities/api.py FastAPI app) under
# uvicorn via poetry. PORT defaults to 3000.
set -ex
poetry run python3 -m uvicorn src.activities.api:app --reload --host 0.0.0.0 --port ${PORT:-3000} --limit-max-requests 5000 --timeout-keep-alive 1200
|
run_demo.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# set -ex
|
| 3 |
+
|
| 4 |
+
streamlit run ./activities/local_demo.py --server.port 6006 -- $@
|
| 5 |
+
|
run_local.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch the local test app (activities/local_test.py) under uvicorn via
# poetry. PORT defaults to 3000.
set -ex
poetry run python3 -m uvicorn src.activities.local_test:local_app --reload --host 0.0.0.0 --port ${PORT:-3000} --limit-max-requests 5000 --timeout-keep-alive 1200
|
| 4 |
+
|
setup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages

# Packaging metadata for the InfiAgent distribution (src/ layout).
setup(
    name='infiagent',
    version='0.1.0',
    author='InfiAgent',
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    url='https://github.com/InfiAgent/ADA-Agent',
    license='LICENSE.txt',
    description='An awesome package for InfiAgent.',
    # Explicit encoding so installation does not depend on the platform's
    # default locale (fixes potential UnicodeDecodeError on non-UTF-8 systems).
    long_description=open('README.md', encoding='utf-8').read(),
    # README.md is markdown; declare it so PyPI renders it correctly.
    long_description_content_type='text/markdown',
    # Ship the bundled YAML configs with the package.
    package_data={
        'infiagent.configs.agent_configs': ['*.yaml'],
        'infiagent.configs.tool_configs': ['*.yaml'],
    },
    install_requires=[
        "streamlit",
        "pyyaml",
        "pytest",
        "openai==0.27.7",
        "fastapi",
        "uvicorn",
        "uvloop",
        "watchdog",
        "chardet",
        "werkzeug",
        "python-dotenv",
        "motor",
        "aiofiles",
        "sse_starlette",
        "loguru",
        "jupyter_client",
        "pandas",
        "scikit-learn",
        "scipy",
        "ipykernel"
    ],
    python_requires='>=3.9'
)
|
src/infiagent/__init__.py
ADDED
|
File without changes
|
src/infiagent/agent/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_agent import BaseAgent
|
| 2 |
+
from .react import AsyncReactAgent
|
src/infiagent/agent/base_agent.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Dict, Callable, Union, AsyncGenerator
|
| 4 |
+
|
| 5 |
+
from ..exceptions.exceptions import InputErrorException
|
| 6 |
+
from ..prompt import PromptTemplate
|
| 7 |
+
from ..schemas import AgentOutput, AgentType, AgentResponse
|
| 8 |
+
|
| 9 |
+
from ..llm.base_llm import BaseLLM
|
| 10 |
+
|
| 11 |
+
from ..tools import BaseTool
|
| 12 |
+
from ..utils import Config, get_logger
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
from importlib import import_module
|
| 16 |
+
|
| 17 |
+
logger = get_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# LLM params that callers may override via kwargs at construction time.
LLM_CONF_OVERRIDE_KEY = ['psm', 'dc', 'temperature', 'top_p', 'top_k', 'max_tokens']


class BaseAgent(ABC):
    """Base Agent class defining the essential attributes and methods for an ALM Agent."""

    def __init__(self, **kwargs):
        """Initialize the agent from keyword configuration.

        Recognized keys: ``name``, ``type``, ``version``, ``description``,
        ``prompt_template`` (dict or PromptTemplate), ``auth`` (mapping that is
        exported to environment variables, typically LLM credentials).
        """
        # Set default values, then overlay whatever the caller provided.
        default_config = {
            'name': 'agent',
            'type': AgentType.react,
            'version': '',
            'description': '',
            'prompt_template': None,
            'auth': {}
        }
        default_config.update(kwargs)

        # Export auth entries as environment variables first so later init
        # steps (e.g. LLM clients) can read credentials from the environment.
        auth = default_config['auth']
        self._set_auth_env(auth)

        self._name: str = default_config['name']
        self._type: AgentType = default_config['type']
        self._version: str = default_config['version']
        self._description: str = default_config['description']
        # NOTE(review): _get_prompt_template asserts dict/PromptTemplate, so a
        # missing 'prompt_template' (None default) fails here — callers are
        # expected to always supply one.
        self.__prompt_template: Union[PromptTemplate, None] = \
            self._get_prompt_template(default_config['prompt_template'])
        self.__llm: Union[BaseLLM, None] = None
        self.__plugins_map: Dict = {}
        # Lazily-built caches (see _get_plugin_function_map / _get_plugin_description).
        self.__plugin_tool_function = {}
        self.__plugin_tool_async_function = {}
        self.__plugin_tool_description = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def type(self) -> AgentType:
        return self._type

    @property
    def version(self) -> str:
        return self._version

    @property
    def description(self) -> str:
        return self._description

    @property
    def prompt_template(self) -> PromptTemplate:
        return self.__prompt_template

    @property
    def llm(self) -> Union[BaseLLM, None]:
        return self.__llm

    @llm.setter
    def llm(self, llm_client: BaseLLM):
        if llm_client is None or not isinstance(llm_client, BaseLLM):
            raise InputErrorException("Invalid llm client {}".format(type(llm_client)))
        self.__llm = llm_client

    @property
    def plugins_map(self) -> Dict:
        return self.__plugins_map.copy()  # Return a copy to prevent external modification

    def add_plugin(self, tool_name: str, tool):
        """Register a plugin (tool or sub-agent) under ``tool_name``."""
        if not tool_name or not tool:
            raise InputErrorException("Adding invalid tool name: {}, type {}".format(tool_name, type(tool)))
        self.__plugins_map[tool_name] = tool

    def _set_auth_env(self, obj):
        """This method sets environment variables for authentication.
        """
        for key in obj:
            os.environ[key] = obj.get(key)

    def _get_prompt_template(self, obj):
        """Return prompt template instance(s) for the provided configuration.

        Accepts a single template (dict config or PromptTemplate) or a dict
        mapping names to template configs; returns the parsed equivalent.
        """
        assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
        if isinstance(obj, dict):
            return {
                key: self._parse_prompt_template(obj[key]) for key in obj
            }
        elif isinstance(obj, PromptTemplate):
            ans = self._parse_prompt_template(obj)
            return ans
        else:
            raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
                                      .format(type(obj)))

    def _parse_prompt_template(self, obj: Union[dict, PromptTemplate]):
        """Parse one prompt template config (dict) into a PromptTemplate instance;
        pass an existing PromptTemplate through unchanged.
        """
        assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
        if isinstance(obj, PromptTemplate):
            return obj
        return PromptTemplate(input_variables=obj['input_variables'],
                              template=obj['template'],
                              validate_template=bool(obj.get('validate_template', True)))

    @classmethod
    def _get_basic_instance_from_config(cls, config_data):
        """Dynamically import and instantiate the agent class named in the config."""
        agent_module_name = config_data.get("module_name", None)
        agent_class_name = config_data.get("class_name", None)
        if not agent_module_name or not agent_class_name:
            raise InputErrorException("Agent module_name and class_name required, please check your config")

        module = import_module(agent_module_name)
        clazz = getattr(module, agent_class_name)
        agent_instance = clazz(**config_data)
        return agent_instance

    @classmethod
    def from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build an agent (with its LLM and plugins) synchronously from a config file."""
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        # Override default LLM params with caller-supplied (truthy) values.
        if 'llm' in config_data and 'params' in config_data['llm']:
            for param in LLM_CONF_OVERRIDE_KEY:
                if param in kwargs and kwargs[param]:
                    logger.info(f"Overwrite with new {param} {kwargs[param]}")
                    config_data['llm']['params'][param] = kwargs[param]

        assert isinstance(agent_instance, BaseAgent)
        agent_instance._init_llm(config_data.get("llm", {}))
        agent_instance._init_plugins(config_data.get('plugins', []))
        return agent_instance

    def _init_llm(self, obj):
        """
        This method parses the Language Model Manager (LLM) configuration and returns an LLM instance.

        :param obj: A configuration dictionary or string.
        :type obj: dict or str
        :raises ValueError: If the specified LLM is not supported.
        :return: An LLM instance.
        :rtype: BaseLLM
        """
        if isinstance(obj, str):
            name = obj
            model_params = dict()
        else:
            name = obj.get('model_name', None)
            model_params = obj.get('params', dict())

        # NOTE(review): when obj is a plain string these subscripts raise
        # TypeError — the str branch above appears to be dead/unsupported.
        module_name = obj['module_name']
        class_name = obj['class_name']

        module = import_module(module_name)
        clazz = getattr(module, class_name)

        llm = clazz(model_name=name, params=model_params)
        self.llm = llm

    def _init_plugins(self, configs):
        """
        Parse the plugin configuration and register each plugin.

        Fix: the original assigned into ``self.plugins_map[...]``, but that
        property returns a defensive *copy*, so every plugin registration was
        silently discarded. Registration now goes through ``add_plugin``,
        which mutates the real internal map.
        """
        assert isinstance(configs, list)
        for plugin_config in configs:
            if plugin_config.get('type', "") == 'agent':
                # Agent as plugin
                agent = BaseAgent.from_config_path_and_kwargs(plugin_config['config'])
                self.add_plugin(plugin_config['name'], agent)
            else:
                # Tools as plugin
                params = plugin_config.get('params', dict())
                tool = BaseTool.from_config(config_input=plugin_config['config'], **params)
                self.add_plugin(tool.name, tool)

    @classmethod
    async def async_from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build an agent asynchronously: the LLM and all plugins are initialized concurrently."""
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        # override default config with user input
        if 'llm' in config_data and 'params' in config_data['llm']:
            for param in LLM_CONF_OVERRIDE_KEY:
                if param in kwargs and kwargs[param]:
                    logger.info(f"Overwrite with new {param} {kwargs[param]}")
                    config_data['llm']['params'][param] = kwargs[param]

        llm_config = config_data.get("llm", {})
        plugin_configs = config_data.get('plugins', [])

        # Create tasks for llm and each individual plugin, then gather results.
        llm_task = asyncio.create_task(cls._async_init_llm(llm_config))
        plugin_tasks = [asyncio.create_task(cls._async_init_plugin(plugin_config)) for
                        plugin_config in plugin_configs]

        llm, *plugins = await asyncio.gather(llm_task, *plugin_tasks)

        agent_instance.llm = llm
        for plugin in plugins:
            plugin_name, plugin_instance = plugin
            agent_instance.add_plugin(plugin_name, plugin_instance)
        return agent_instance

    @classmethod
    async def _async_init_llm(cls, llm_config):
        """Import the configured LLM class and await its async factory."""
        llm_model_name = llm_config.get("module_name", None)
        llm_class_name = llm_config.get("class_name", None)
        if not llm_model_name or not llm_class_name:
            raise InputErrorException("Agent LLM module_name and class_name required, please check your config")
        module = import_module(llm_model_name)
        clazz = getattr(module, llm_class_name)
        assert issubclass(clazz, BaseLLM), f"{clazz} is not a subclass of BaseLLM"
        llm_instance = await clazz.create(config_data=llm_config)
        return llm_instance

    @classmethod
    async def _async_init_plugin(cls, plugin_config):
        """Initialize a single plugin (sub-agent or tool); return (name, instance)."""
        if plugin_config.get('type', "") == 'agent':
            # Agent as plugin
            agent = await BaseAgent.async_from_config_path_and_kwargs(plugin_config['config'])
            return plugin_config['name'], agent
        else:
            # Tool as plugin
            params = plugin_config.get('params', dict())
            name = plugin_config.get('name', None)
            config = plugin_config['config']

            tool = await BaseTool.async_from_config(config_input=config, **params)

            # Fall back to the tool's self-declared name when none is configured.
            if name is None:
                name = tool.name
            logger.info("Init tool with name [{}], and description [{}]".format(name, tool.description))
            return name, tool

    @abstractmethod
    def run(self, *args, **kwargs) -> Union[AgentResponse, None]:
        """Abstract method to be overridden by child classes for running the agent.

        (Fix: annotation was the list literal ``[AgentResponse, None]``;
        ``Union[AgentResponse, None]`` is the valid spelling.)

        :return: The output of the agent.
        :rtype: AgentOutput
        """
        pass

    async def async_run(self, *args, **kwargs) -> AsyncGenerator[AgentResponse, None]:
        """Default async wrapper: yield the synchronous run() result once.

        :return: The output of the agent.
        """
        yield self.run(*args, **kwargs)

    def _get_plugin_function_map(self, method_name: str) -> Dict[str, Callable]:
        """Build (and cache) a name -> bound-method map for all plugins.

        ``method_name`` is "run" or "async_run"; each is cached separately.
        """
        if method_name == "run" and self.__plugin_tool_function:
            return self.__plugin_tool_function
        elif method_name == "async_run" and self.__plugin_tool_async_function:
            return self.__plugin_tool_async_function

        function_map = {}

        for name, plugin_tool in self.plugins_map.items():
            if isinstance(plugin_tool, (BaseTool, BaseAgent)):
                function_map[name] = getattr(plugin_tool, method_name)
            else:
                logger.warning(f"No support for plugin name {name} of type {type(plugin_tool)}")

        if method_name == "run":
            self.__plugin_tool_function = function_map
        elif method_name == "async_run":
            self.__plugin_tool_async_function = function_map

        return function_map

    def get_plugin_tool_function(self) -> Dict[str, Callable]:
        """Format the function map for the function API.

        :return: The function map.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("run")

    def get_plugin_tool_async_function(self) -> Dict[str, Callable]:
        """Format the function map for the function API.

        :return: The function map.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("async_run")

    def _get_plugin_description(self):
        """Build (and cache) a newline-joined "name[input]: description" listing of all plugins."""
        if self.__plugin_tool_description:
            return self.__plugin_tool_description

        descriptions = ""
        try:
            for plugin_name, plugin in self.plugins_map.items():
                descriptions += f"{plugin_name}[input]: {plugin.description}\n"
        except Exception as e:
            err_msg = "Failed to get plugin tool name and description. error: {}".format(str(e))
            raise InputErrorException(err_msg) from e

        self.__plugin_tool_description = descriptions
        return descriptions

    def clear(self):
        """
        Clear and reset the agent.
        """
        pass
|
src/infiagent/agent/react/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .async_react_agent import AsyncReactAgent
|
| 2 |
+
__all__ = [
|
| 3 |
+
'AsyncReactAgent'
|
| 4 |
+
]
|
src/infiagent/agent/react/async_react_agent.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import time
|
| 3 |
+
from typing import Union, List, Dict
|
| 4 |
+
|
| 5 |
+
from werkzeug.datastructures import FileStorage
|
| 6 |
+
|
| 7 |
+
from .. import BaseAgent
|
| 8 |
+
from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException
|
| 9 |
+
from ...schemas import (
|
| 10 |
+
AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse,
|
| 11 |
+
BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile
|
| 12 |
+
)
|
| 13 |
+
from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool
|
| 14 |
+
from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \
|
| 15 |
+
OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \
|
| 16 |
+
TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN
|
| 17 |
+
|
| 18 |
+
# Fix: this entire constant block (and the logger) was defined twice in a
# row with identical values; keep a single definition of each.
# Name of the sandbox plugin the agent looks up in its plugins map.
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
# Markers in LLM output that indicate the agent produced a final answer.
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
# Markdown fences used to extract python code blocks from LLM output.
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
# Stop sequence for generation so the model halts before inventing observations.
STOP_WORD = ['Observation:']

logger = get_logger()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class AsyncReactAgent(BaseAgent):
    """ReAct-style agent that alternates LLM "thoughts" with python-sandbox
    tool executions until a final answer is produced.

    The agent is asynchronous end to end: ``async_run`` yields one
    ``AgentResponse`` per intermediate step (LLM thought or sandbox
    observation) so callers can stream progress to the user.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        # Alternating AgentAction / AgentObservation steps; used to rebuild
        # the scratchpad section of the prompt on every round.
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self):
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
        # Synchronous execution is intentionally unsupported; use async_run.
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
        """Upload a local file into the sandbox and return its sandbox path."""
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        # FIX: the original isinstance check used the tuple
        # (AsyncPythonSandBoxTool, AsyncPythonSandBoxTool) — a duplicated element.
        if not isinstance(sandbox_plugin, AsyncPythonSandBoxTool):
            raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
        """Run the full ReAct loop for a request, yielding each step's response."""
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
                    max_single_step_iterations=3):
        """Drive up to ``max_iterations`` thought/action rounds.

        Each round first asks the LLM for a thought (retried up to
        ``max_single_step_iterations`` times), yields it, and — unless it is a
        final answer — executes the proposed action and yields the observation.
        """
        current_iteration = 0

        for _ in range(max_iterations):
            current_iteration += 1
            llm_response = await self._single_round_thought(instruction,
                                                           max_llm_iteration=max_single_step_iterations,
                                                           is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Find final answer, stop iteration.")
                break

            self.intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
                                                                                 max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.intermediate_steps.append(action_response)

            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)

        logger.info(f"Finished iteration in {current_iteration}.")

    # TODO update logic to not be sandbox specific, sandbox related logic should be handled in sandbox client
    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
        """Execute the action proposed by the LLM in the sandbox.

        Returns an ``(AgentObservation, output_files)`` pair; raises
        SandboxException when the tool invocation fails.
        """
        try:
            # TODO: we hard-wire the tool because the sandbox is currently the
            # only plugin; reuse the module constant instead of a new literal.
            response.tool = SAND_BOX_PLUGIN_NAME
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            if "STDERR" in action_response.output_text:
                # Error output can be huge; trim it to keep the prompt small.
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text

            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                 formatted_output=formatted_output,
                                                 raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files

        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            # TODO: We hard code here as we only have one tool
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        # FIX: guard against a missing template BEFORE using it — the original
        # called construct_scratchpad() first, making the None check dead code.
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is none, please check init process")

        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(list(self.plugins_map.keys()))

        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
        """Ask the LLM for one thought, retrying on LLM/parse failures.

        Returns the parsed AgentAction/AgentFinish, or a failure AgentFinish
        after exhausting the retry budget.
        """
        llm_iteration_count = 0

        llm_response = None
        while llm_iteration_count <= max_llm_iteration:
            llm_iteration_count += 1
            try:
                llm_response = await self._get_llm_response(instruction)
                action_response = self._parse_output(llm_response.content, is_cn)

                return action_response
            except Exception as e:
                logger.error("LLM iteration {} out of {} failed. Error: {}".
                             format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)

        if llm_iteration_count > max_llm_iteration:
            logger.error("LLM iteration {} exceed max retry {}. Aborting".
                         format(llm_iteration_count, max_llm_iteration))
            return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                               raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
        """Send the composed prompt to the LLM; raise LLMException on error state."""
        prompt = self._compose_prompt(instruction)
        logger.info("Send prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from llm, raw response content: \n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM text into an AgentAction or AgentFinish.

        Raises LLMException when the output matches neither a final answer
        nor a recognizable Action / Action Input structure.
        """
        # Truncate anything the model hallucinated after a stop word.
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        # Check for Final Answer, if it is final, then just return
        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                # got final answer and remove the indicator
                parts = llm_output.split(indicator)
                # formatted_output = ''.join(parts[:-1]).strip()
                formatted_output = ''.join(parts).strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        # Two alternatives: "Action: ... Action Input: ```python ...```" form
        # (groups 1-4) or a bare '''lang\ncode''' form (groups 5-8).
        ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
        ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"

        action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL)

        # Find action, context, and action input, build action response
        if action_match:
            # FIX: only one alternative's groups are populated; the original
            # unconditionally called .strip() on groups 1-3 and raised
            # AttributeError whenever the second alternative matched.
            if action_match.group(1) is not None:
                context = action_match.group(1).strip()
                action_tool_description = action_match.group(2).strip()
                action_input = action_match.group(3).strip()
            else:
                context = (action_match.group(5) or "").strip()
                action_tool_description = (action_match.group(6) or "").strip()
                action_input = (action_match.group(7) or "").strip()

            # Format code
            # TODO: currently we only have one plugin which is sandbox, update to support multiple tools
            format_code_block = self._format_code_block(action_input)

            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)

            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        # Not final answer and not action, raise exception
        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    def _format_code_block(self, tool_input):
        """Normalize tool input into a well-formed ```python fenced block."""
        stripped_tool_input = tool_input.strip()

        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Already a python fence; just ensure a newline after the opening tag.
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                                      '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Plain ``` fence: upgrade the opening tag to ```python.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            # Bare code: wrap it in a fresh fence.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'

        return formatted_code.encode("utf-8").decode("utf-8")

    def _process_sandbox_output(self, output: str):
        """Function to process the result containing STDERR.

        Outputs longer than 1000 characters are trimmed to roughly the first
        and last 500 characters (on line boundaries) joined by "......".
        """
        if len(output) <= 1000:
            return output

        logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
                    format(output))
        rows = output.split("\n")
        # Get the first 500 characters, respecting line boundaries
        top_segment = []
        length = 0
        for sub_p in rows:
            if length + len(sub_p) > 500:
                break
            top_segment.append(sub_p)
            length += len(sub_p)

        # Get the last 500 characters, respecting line boundaries
        bottom_segment = []
        length = 0
        for sub_p in reversed(rows):
            if length + len(sub_p) > 500:
                break
            bottom_segment.insert(0, sub_p)
            length += len(sub_p)

        # Combine the segments with "......" in between
        timed_output = "\n".join(top_segment + ["......"] + bottom_segment)

        return timed_output

    def _get_output_files(self, tool_response) -> list[MediaFile]:
        """Collect generated files / image outputs from a sandbox response."""
        output_files = []

        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output

            # Only harvest files from complete, successful runs.
            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result

                # TODO confirm if we still need output and format
                if len(result_data.new_generated_files) > 0:
                    output_files.extend([MediaFile(tos_path=file.download_link) for file in
                                         result_data.new_generated_files])

                if len(result_data.code_output_result) > 0:
                    output_files.extend(
                        [MediaFile(tos_path=image.content) for image in result_data.code_output_result
                         if image.type == 'image'])

        return output_files

    def _replace_csv_path(self, input_string):
        """Replace concrete pd.read_csv paths with a placeholder path."""
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output):
        """Build the streaming AgentResponse payload for one step."""
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)
|
| 299 |
+
|
src/infiagent/conversation_sessions/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .code_interpreter_session import CodeInterpreterSession
|
src/infiagent/conversation_sessions/code_interpreter_session.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Any, Dict, Union
|
| 5 |
+
|
| 6 |
+
from werkzeug.datastructures import FileStorage
|
| 7 |
+
|
| 8 |
+
from ..agent import BaseAgent
|
| 9 |
+
from ..agent.react import AsyncReactAgent
|
| 10 |
+
from ..schemas import AgentRequest, MediaFile, Message, RoleType
|
| 11 |
+
from ..utils import generate_random_string, get_logger, get_model_config_path
|
| 12 |
+
|
| 13 |
+
logger = get_logger()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CodeInterpreterSession:
    """One code-interpreter conversation: an agent plus its sandbox, message
    history, and uploaded/produced files.

    Instances should be built via the async ``create`` factory, which
    provisions the sandbox and agent before ``__init__`` runs.
    """

    def __init__(
            self,
            session_id: Union[None, str] = None,
            model_name: Union[None, str] = "openai",
            config_path: Union[None, str] = None,
            agent: AsyncReactAgent = None,
            **kwargs):
        self.session_id = session_id
        self.config_path = config_path
        self.input_files = []
        self.output_files = []
        self.messages = []
        self.agent = agent
        self.llm_model_name = self.agent.llm.model_name

        # FIX: the original format string had 3 placeholders but 4 arguments,
        # silently dropping the session id from the log line.
        logger.info("Use model {} and llm {} in config {} for conversation {}"
                    .format(model_name, self.llm_model_name, self.config_path, self.session_id))

    @classmethod
    async def create(cls,
                     model_name: Union[None, str] = "openai",
                     config_path: Union[None, str] = None,
                     **kwargs: Dict[str, Any]):
        """Async factory: resolve config, provision a sandbox, build the agent."""
        if config_path is None:
            config_path = get_model_config_path(model_name)
        logger.info(f"Use Config Path: {config_path}")

        # The random string doubles as both sandbox id and session id.
        sandbox_id = generate_random_string(12)

        # setup agent
        agent = await BaseAgent.async_from_config_path_and_kwargs(config_path, **kwargs)
        await agent.plugins_map["python_code_sandbox"].set_sandbox_id(sandbox_id)

        return cls(session_id=sandbox_id,
                   model_name=model_name,
                   config_path=config_path,
                   agent=agent)

    async def upload_to_sandbox(self, file: Union[str, FileStorage]):
        """Sync a user file into the sandbox and record it in the history."""
        dst_path = await self.agent.sync_to_sandbox(file)
        message = f'User uploaded the following files: {dst_path}\n'
        logging.info(f"The file path {file} has been synced to sandbox with file path {dst_path}")
        self.messages.append(Message(RoleType.System, message))
        self.input_files.append(MediaFile(file_name=os.path.basename(dst_path), sandbox_path=dst_path))

    async def chat(self, user_messages, input_files=None):
        """Run one agent turn, streaming intermediate AgentResponses."""
        start_time = time.time()

        self.messages.extend(user_messages)
        agent_request = AgentRequest(
            messages=self.messages,
            input_files=self.input_files,
            sandbox_id=self.session_id
        )
        logger.info(f"Agent request: {agent_request.__dict__}")

        async for agent_response in self.agent.async_run(agent_request):
            logger.info(f"Agent response:\n{agent_response.output_text}")
            # Intermediate outputs are folded back into the history so the
            # next turn's prompt sees them.
            self.messages.append(Message(RoleType.System, agent_response.output_text))
            yield agent_response

        exec_time = time.time()
        logger.info(
            f'Agent Execution Latency: {exec_time - start_time}'
        )

    def __enter__(self):
        # FIX: return self so `with session as s:` binds the session
        # (the original returned None).
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # No teardown yet; sandbox lifecycle is managed externally.
        pass
|
src/infiagent/exceptions/__init__.py
ADDED
|
File without changes
|
src/infiagent/exceptions/exceptions.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class DependencyException(Exception):
    """Base for failures in external dependencies (DB, sandbox, LLM, plugins)."""
    pass


class InputErrorException(Exception):
    """Base for caller-supplied bad input (invalid config, bad arguments)."""
    pass


class InternalErrorException(Exception):
    """Unexpected internal state; indicates a bug or broken initialization."""
    pass


# NOTE: each subclass below previously carried an identical
# `__init__(self, message, *args)` that only forwarded to super();
# Exception's default constructor already does exactly that, so the
# boilerplate has been dropped with no behavioral change.

class DatabaseException(DependencyException):
    """Database access failed."""


class SandboxException(DependencyException):
    """Code sandbox invocation failed."""


class LLMException(DependencyException):
    """LLM request failed or returned an unusable response."""


class ModelMaxIterationsException(DependencyException):
    """Agent exceeded its maximum reasoning iterations."""


class InvalidConfigException(InputErrorException):
    """Supplied configuration is malformed or incomplete."""


class SandBoxFileUploadException(SandboxException):
    """File upload into the sandbox failed."""


class PluginException(DependencyException):
    """A tool plugin raised an error during execution."""
|
| 46 |
+
|
src/infiagent/llm/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .client.openai import *
|
| 2 |
+
from .client.azure_openai import *
|
| 3 |
+
from .client.opt import *
|
| 4 |
+
from .client.llama import *
|
| 5 |
+
from .base_llm import *
|
src/infiagent/llm/base_llm.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
|
| 3 |
+
from ..exceptions.exceptions import InputErrorException
|
| 4 |
+
from ..schemas import BaseCompletion
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseLLM(ABC):
    """Common interface for LLM clients: a model name, a parameter dict, and
    sync/async completion hooks that concrete backends override."""

    def __init__(self, model_name: str, params: dict, **kwargs):
        # Name-mangled storage keeps subclasses from clobbering these fields.
        self.__model_name = model_name
        self.__params = params

    @classmethod
    async def create(cls, config_data: dict):
        """Async factory hook; concrete clients construct themselves here."""
        pass

    @property
    def model_name(self) -> str:
        """The configured model identifier."""
        return self.__model_name

    @model_name.setter
    def model_name(self, model_name):
        # Reject a None model name outright; everything else is accepted.
        if model_name is None:
            raise InputErrorException("Invalid model_name {}".format(model_name))
        self.__model_name = model_name

    @property
    def params(self) -> dict:
        """Read-only view of the generation parameters."""
        return self.__params

    def completion(self, prompt) -> BaseCompletion:
        """Synchronous completion hook; overridden by concrete clients."""
        pass

    async def async_completion(self, prompt) -> BaseCompletion:
        """Asynchronous completion hook; overridden by concrete clients."""
        pass
|
| 36 |
+
|
src/infiagent/llm/client/__init__.py
ADDED
|
File without changes
|
src/infiagent/llm/client/azure_openai.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from typing import Callable, List
|
| 6 |
+
|
| 7 |
+
import openai
|
| 8 |
+
from tenacity import ( # for exponential backoff
|
| 9 |
+
before_sleep_log,
|
| 10 |
+
retry,
|
| 11 |
+
stop_after_attempt,
|
| 12 |
+
wait_random_exponential,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from ..base_llm import BaseLLM
|
| 16 |
+
from ...schemas import *
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
MAX_PROMPT_LENGTH = 7000
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(100), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
def chatcompletion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with exponential-backoff retry.

    Retries up to 100 times with 1-10s randomized exponential waits, logging
    a warning before each sleep; the final failure is re-raised to the caller.
    """
    return openai.ChatCompletion.create(**kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(100), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
async def async_chatcompletion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.acreate`` with exponential-backoff retry.

    FIX: the original defined and immediately awaited a nested
    ``_internal_coroutine`` wrapper around the API call; tenacity's ``retry``
    supports async callables directly, so the wrapper added only overhead.
    Retry policy matches the synchronous variant: up to 100 attempts, 1-10s
    randomized exponential waits, warning logged before each sleep.
    """
    return await openai.ChatCompletion.acreate(**kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AzureOpenAIGPTClient(BaseLLM, ABC):
|
| 39 |
+
"""
|
| 40 |
+
Wrapper class for OpenAI GPT API collections.
|
| 41 |
+
|
| 42 |
+
:param model_name: The name of the model to use.
|
| 43 |
+
:type model_name: str
|
| 44 |
+
:param params: The parameters for the model.
|
| 45 |
+
:type params: AzureOpenAIParamModel
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
model_name: str
|
| 49 |
+
params: AzureOpenAIParamModel = AzureOpenAIParamModel()
|
| 50 |
+
|
| 51 |
+
def __init__(self, **data):
|
| 52 |
+
super().__init__(**data)
|
| 53 |
+
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
|
| 54 |
+
openai.api_type = "azure"
|
| 55 |
+
openai.api_base = "https://search.bytedance.net/gpt/openapi/online/v2/crawl"
|
| 56 |
+
openai.api_version = "2023-06-01-preview"
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
async def create(cls, config_data):
|
| 60 |
+
return AzureOpenAIGPTClient(**config_data)
|
| 61 |
+
|
| 62 |
+
def get_model_name(self) -> str:
|
| 63 |
+
return self.model_name
|
| 64 |
+
|
| 65 |
+
def get_model_param(self) -> AzureOpenAIParamModel:
|
| 66 |
+
return self.params
|
| 67 |
+
|
| 68 |
+
def completion(self, prompt: str, **kwargs) -> BaseCompletion:
|
| 69 |
+
"""
|
| 70 |
+
Completion method for OpenAI GPT API.
|
| 71 |
+
|
| 72 |
+
:param prompt: The prompt to use for completion.
|
| 73 |
+
:type prompt: str
|
| 74 |
+
:param kwargs: Additional keyword arguments.
|
| 75 |
+
:type kwargs: dict
|
| 76 |
+
:return: BaseCompletion object.
|
| 77 |
+
:rtype: BaseCompletion
|
| 78 |
+
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
response = chatcompletion_with_backoff(
|
| 82 |
+
engine=self.get_model_name(), # GPT-4
|
| 83 |
+
messages=[
|
| 84 |
+
{"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
|
| 85 |
+
],
|
| 86 |
+
timeout=1000,
|
| 87 |
+
temperature=self.params.temperature,
|
| 88 |
+
max_tokens=self.params.max_tokens,
|
| 89 |
+
top_p=self.params.top_p,
|
| 90 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 91 |
+
presence_penalty=self.params.presence_penalty,
|
| 92 |
+
**kwargs
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
return BaseCompletion(state="success",
|
| 96 |
+
content=response.choices[0].message["content"],
|
| 97 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 98 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 99 |
+
|
| 100 |
+
async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
|
| 101 |
+
"""
|
| 102 |
+
Completion method for OpenAI GPT API.
|
| 103 |
+
|
| 104 |
+
:param prompt: The prompt to use for completion.
|
| 105 |
+
:type prompt: str
|
| 106 |
+
:param kwargs: Additional keyword arguments.
|
| 107 |
+
:type kwargs: dict
|
| 108 |
+
:return: BaseCompletion object.
|
| 109 |
+
:rtype: BaseCompletion
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
response = await async_chatcompletion_with_backoff(
|
| 113 |
+
engine=self.get_model_name(),
|
| 114 |
+
messages=[
|
| 115 |
+
{"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
|
| 116 |
+
],
|
| 117 |
+
timeout=1000,
|
| 118 |
+
temperature=self.params.temperature,
|
| 119 |
+
max_tokens=self.params.max_tokens,
|
| 120 |
+
top_p=self.params.top_p,
|
| 121 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 122 |
+
presence_penalty=self.params.presence_penalty,
|
| 123 |
+
**kwargs
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
return BaseCompletion(state="success",
|
| 127 |
+
content=response.choices[0].message["content"],
|
| 128 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 129 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 130 |
+
|
| 131 |
+
def chat_completion(self, message: List[dict]) -> ChatCompletion:
|
| 132 |
+
"""
|
| 133 |
+
Chat completion method for OpenAI GPT API.
|
| 134 |
+
|
| 135 |
+
:param message: The message to use for completion.
|
| 136 |
+
:type message: List[dict]
|
| 137 |
+
:return: ChatCompletion object.
|
| 138 |
+
:rtype: ChatCompletion
|
| 139 |
+
"""
|
| 140 |
+
try:
|
| 141 |
+
response = openai.ChatCompletion.create(
|
| 142 |
+
engine=self.get_model_name(), # GPT-4
|
| 143 |
+
messages=message,
|
| 144 |
+
timeout=1000,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
return ChatCompletion(
|
| 148 |
+
state="success",
|
| 149 |
+
role=response.choices[0].message["role"],
|
| 150 |
+
content=response.choices[0].message["content"],
|
| 151 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 152 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0),
|
| 153 |
+
)
|
| 154 |
+
except Exception as exception:
|
| 155 |
+
print("Exception:", exception)
|
| 156 |
+
return ChatCompletion(state="error", content=exception)
|
| 157 |
+
|
| 158 |
+
def stream_chat_completion(self, message: List[dict], **kwargs):
|
| 159 |
+
"""
|
| 160 |
+
Stream output chat completion for OpenAI GPT API.
|
| 161 |
+
|
| 162 |
+
:param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
|
| 163 |
+
:type message: List[dict]
|
| 164 |
+
:param kwargs: Additional keyword arguments.
|
| 165 |
+
:type kwargs: dict
|
| 166 |
+
:return: ChatCompletion object.
|
| 167 |
+
:rtype: ChatCompletion
|
| 168 |
+
"""
|
| 169 |
+
try:
|
| 170 |
+
response = openai.ChatCompletion.create(
|
| 171 |
+
engine=self.get_model_name(), # GPT-4
|
| 172 |
+
messages=message,
|
| 173 |
+
timeout=1000,
|
| 174 |
+
**kwargs,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
role = next(response).choices[0].delta["role"]
|
| 178 |
+
messages = []
|
| 179 |
+
## TODO: Calculate prompt_token and for stream mode
|
| 180 |
+
for resp in response:
|
| 181 |
+
messages.append(resp.choices[0].delta.get("content", ""))
|
| 182 |
+
yield ChatCompletion(
|
| 183 |
+
state="success",
|
| 184 |
+
role=role,
|
| 185 |
+
content=messages[-1],
|
| 186 |
+
prompt_token=0,
|
| 187 |
+
completion_token=0,
|
| 188 |
+
)
|
| 189 |
+
except Exception as exception:
|
| 190 |
+
print("Exception:", exception)
|
| 191 |
+
return ChatCompletion(state="error", content=exception)
|
| 192 |
+
|
| 193 |
+
    def function_chat_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion with OpenAI function-calling support.

        Sends the scratchpad once; if the model requests a function call,
        runs the mapped local callable and issues a second request so the
        model can incorporate the tool output into its final answer.

        :param message: The message to use for completion. Mutated in place:
            the assistant turn, the function result turn, and the final
            assistant turn are appended as the exchange proceeds.
        :type message: List[dict]
        :param function_map: The function map to use for completion
            (function name -> local callable).
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema to use for completion.
        :type function_schema: List[Dict]
        :return: ChatCompletionWithHistory object; on failure ``state`` is
            "error" and ``content`` is the exception text.
        :rtype: ChatCompletionWithHistory
        """
        # One schema entry is expected per mapped callable.
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                engine=self.get_model_name(),  # GPT-4
                messages=message,
                functions=function_schema,
                timeout=1000,
            )
            # response = openai.ChatCompletion.create(
            #     n=self.params.n,
            #     model=self.model_name,
            #     messages=message,
            #     functions=function_schema,
            #     temperature=self.params.temperature,
            #     max_tokens=self.params.max_tokens,
            #     top_p=self.params.top_p,
            #     frequency_penalty=self.params.frequency_penalty,
            #     presence_penalty=self.params.presence_penalty,
            # )
            response_message = response.choices[0]["message"]

            if response_message.get("function_call"):
                # Model asked for a tool: resolve and invoke the local
                # callable with the JSON-decoded arguments.
                function_name = response_message["function_call"]["name"]
                fuction_to_call = function_map[function_name]
                function_args = json.loads(
                    response_message["function_call"]["arguments"]
                )
                function_response = fuction_to_call(**function_args)

                # Postprocess function response: plain strings carry no
                # cost/token accounting; AgentOutput does.
                if isinstance(function_response, str):
                    plugin_cost = 0
                    plugin_token = 0
                elif isinstance(function_response, AgentOutput):
                    plugin_cost = function_response.cost
                    plugin_token = function_response.token_usage
                    function_response = function_response.output
                else:
                    raise Exception(
                        "Invalid tool response type. Must be on of [AgentOutput, str]"
                    )

                # Record the assistant's function_call turn, then the tool
                # result, before asking the model for the final answer.
                message.append(dict(response_message))
                message.append(
                    {
                        "role": "function",
                        "name": function_name,
                        "content": function_response,
                    }
                )
                second_response = openai.ChatCompletion.create(
                    model=self.get_model_name(),
                    messages=message,
                )
                message.append(dict(second_response.choices[0].message))
                # Token counts sum both API round trips.
                return ChatCompletionWithHistory(
                    state="success",
                    role=second_response.choices[0].message["role"],
                    content=second_response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
                    + second_response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    )
                    + second_response.get("usage", {}).get("completion_tokens", 0),
                    message_scratchpad=message,
                    plugin_cost=plugin_cost,
                    plugin_token=plugin_token,
                )
            else:
                # No tool requested: return the first response directly.
                message.append(dict(response_message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=response.choices[0].message["role"],
                    content=response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    ),
                    message_scratchpad=message,
                )

        except Exception as exception:
            print("Exception:", exception)
            return ChatCompletionWithHistory(state="error", content=str(exception))
|
| 296 |
+
|
| 297 |
+
    def function_chat_stream_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Streaming chat completion with function-calling support.

        Generator: yields ``(type, ChatCompletionWithHistory)`` tuples where
        ``type`` is either "content" (plain text deltas) or "function_call"
        (argument-JSON deltas). The mapped callables in ``function_map`` are
        not invoked here — only the model's streamed request is relayed.

        :param message: The message scratchpad to send.
        :type message: List[dict]
        :param function_map: Function name -> local callable (used only for
            the schema/map consistency assertion here).
        :type function_map: Dict[str, Callable]
        :param function_schema: Function schemas advertised to the model.
        :type function_schema: List[Dict]
        :raises Exception: re-raises any API/stream failure after logging.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.get_model_name(),
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
            )
            # First chunk fixes the role and tells us whether the model is
            # streaming plain content or a function call (content is None
            # when a function_call is being produced).
            tmp = next(response)
            role = tmp.choices[0].delta["role"]
            _type = (
                "function_call"
                if tmp.choices[0].delta["content"] is None
                else "content"
            )
            if _type == "function_call":
                # Emit the opening of the call JSON; subsequent deltas
                # stream the "arguments" value.
                name = tmp.choices[0].delta["function_call"]["name"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )
            for resp in response:
                # print(resp)
                content = resp.choices[0].delta.get(_type, "")
                if isinstance(content, dict):
                    # function_call deltas arrive as dicts; forward only
                    # the streamed arguments fragment.
                    content = content["arguments"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )

        except Exception as e:
            logger.error(f"Failed to get response {str(e)}", exc_info=True)
            raise e
|
src/infiagent/llm/client/llama.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from typing import Callable, List
|
| 6 |
+
|
| 7 |
+
import openai
|
| 8 |
+
from tenacity import ( # for exponential backoff
|
| 9 |
+
before_sleep_log,
|
| 10 |
+
retry,
|
| 11 |
+
stop_after_attempt,
|
| 12 |
+
wait_random_exponential,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from ..base_llm import BaseLLM
|
| 16 |
+
from ...schemas import *
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
MAX_PROMPT_LENGTH = 4096
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(5),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def chatcompletion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with exponential-backoff retries.

    Up to 5 attempts with random exponential waits of 1-10s; the final
    exception is re-raised, and each back-off is logged at WARNING.
    """
    create = openai.ChatCompletion.create
    return create(**kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(5),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def async_chatcompletion_with_backoff(**kwargs):
    """Async ``openai.ChatCompletion.acreate`` with exponential-backoff retries.

    Up to 5 attempts with random exponential waits of 1-10s; the final
    exception is re-raised, and each back-off is logged at WARNING.

    Simplification: the original defined and awaited a redundant inner
    coroutine (``_internal_coroutine``) that only forwarded the awaited
    call — awaiting the API call directly is behaviorally identical and
    drops the extra frame per invocation.

    :param kwargs: Forwarded verbatim to ``openai.ChatCompletion.acreate``.
    :return: The API response object.
    """
    return await openai.ChatCompletion.acreate(**kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LlamaOpenAIClient(BaseLLM, ABC):
    """
    Wrapper class for a Llama model served behind an OpenAI-compatible API.

    Requests go through the ``openai`` client, redirected in ``__init__``
    to a locally hosted endpoint.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The parameters for the model.
    :type params: LlamaParamModel
    """

    # Served model name, passed as ``model=`` on each request.
    model_name: str
    # Sampling/decoding parameters; defaults come from LlamaParamModel.
    params: LlamaParamModel = LlamaParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # NOTE(review): this mutates *module-global* openai configuration,
        # and the base URL is hard-coded to a local server on port 9729 —
        # every other openai client in the process is redirected too.
        # Consider making the endpoint configurable; confirm no other
        # OpenAI-backed client runs in the same process.
        openai.api_key = ""
        openai.api_base = "http://0.0.0.0:9729/v1"

    @classmethod
    async def create(cls, config_data):
        """Async factory: build a client from a config mapping."""
        return LlamaOpenAIClient(**config_data)

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model_name

    def get_model_param(self) -> LlamaParamModel:
        """Return the model parameter object."""
        return self.params
|
| 65 |
+
|
| 66 |
+
    def completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Completion method for OpenAI GPT API.

        :param prompt: The prompt to use for completion. Only the last
            ``MAX_PROMPT_LENGTH`` *characters* are sent — a character cut,
            not a token count, so the head of a long prompt is silently
            dropped.
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        # Retried with exponential backoff; no try/except here, so after
        # retries are exhausted the exception propagates to the caller.
        response = chatcompletion_with_backoff(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            # temperature=self.params.temperature,
            # max_tokens=self.params.max_tokens,
            # top_p=self.params.top_p,
            # frequency_penalty=self.params.frequency_penalty,
            # presence_penalty=self.params.presence_penalty,
            # stop=["<|im_end|>", "<|endoftext|>"],
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                              completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 96 |
+
|
| 97 |
+
    async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Async completion method for OpenAI GPT API.

        :param prompt: The prompt to use for completion. Only the last
            ``MAX_PROMPT_LENGTH`` *characters* are sent (character cut,
            not a token count).
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        # Retried with exponential backoff; exceptions propagate after
        # the retries are exhausted.
        response = await async_chatcompletion_with_backoff(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            #temperature=0.2,
            #max_tokens=4096,
            #top_p=0.9,
            #frequency_penalty=self.params.frequency_penalty,
            #presence_penalty=self.params.presence_penalty,
            # stop=["<|im_end|>", "<|endoftext|>"],
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                              completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 128 |
+
|
| 129 |
+
def chat_completion(self, message: List[dict]) -> ChatCompletion:
|
| 130 |
+
"""
|
| 131 |
+
Chat completion method for OpenAI GPT API.
|
| 132 |
+
|
| 133 |
+
:param message: The message to use for completion.
|
| 134 |
+
:type message: List[dict]
|
| 135 |
+
:return: ChatCompletion object.
|
| 136 |
+
:rtype: ChatCompletion
|
| 137 |
+
"""
|
| 138 |
+
try:
|
| 139 |
+
response = openai.ChatCompletion.create(
|
| 140 |
+
n=self.params.n,
|
| 141 |
+
model=self.model_name,
|
| 142 |
+
timeout=1000,
|
| 143 |
+
messages=message,
|
| 144 |
+
temperature=self.params.temperature,
|
| 145 |
+
max_tokens=self.params.max_tokens,
|
| 146 |
+
top_p=self.params.top_p,
|
| 147 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 148 |
+
presence_penalty=self.params.presence_penalty,
|
| 149 |
+
)
|
| 150 |
+
return ChatCompletion(
|
| 151 |
+
state="success",
|
| 152 |
+
role=response.choices[0].message["role"],
|
| 153 |
+
content=response.choices[0].message["content"],
|
| 154 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 155 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0),
|
| 156 |
+
)
|
| 157 |
+
except Exception as exception:
|
| 158 |
+
print("Exception:", exception)
|
| 159 |
+
return ChatCompletion(state="error", content=exception)
|
| 160 |
+
|
| 161 |
+
def stream_chat_completion(self, message: List[dict], **kwargs):
|
| 162 |
+
"""
|
| 163 |
+
Stream output chat completion for OpenAI GPT API.
|
| 164 |
+
|
| 165 |
+
:param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
|
| 166 |
+
:type message: List[dict]
|
| 167 |
+
:param kwargs: Additional keyword arguments.
|
| 168 |
+
:type kwargs: dict
|
| 169 |
+
:return: ChatCompletion object.
|
| 170 |
+
:rtype: ChatCompletion
|
| 171 |
+
"""
|
| 172 |
+
try:
|
| 173 |
+
# response = openai.ChatCompletion.create(
|
| 174 |
+
# engine=self.get_model_name(), # GPT-4
|
| 175 |
+
# messages=message,
|
| 176 |
+
# timeout=1000,
|
| 177 |
+
# **kwargs,
|
| 178 |
+
# )
|
| 179 |
+
response = openai.ChatCompletion.create(
|
| 180 |
+
n=self.params.n,
|
| 181 |
+
model=self.model_name,
|
| 182 |
+
messages=message,
|
| 183 |
+
temperature=self.params.temperature,
|
| 184 |
+
max_tokens=self.params.max_tokens,
|
| 185 |
+
top_p=self.params.top_p,
|
| 186 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 187 |
+
presence_penalty=self.params.presence_penalty,
|
| 188 |
+
stream=True,
|
| 189 |
+
**kwargs
|
| 190 |
+
)
|
| 191 |
+
role = next(response).choices[0].delta["role"]
|
| 192 |
+
messages = []
|
| 193 |
+
## TODO: Calculate prompt_token and for stream mode
|
| 194 |
+
for resp in response:
|
| 195 |
+
messages.append(resp.choices[0].delta.get("content", ""))
|
| 196 |
+
yield ChatCompletion(
|
| 197 |
+
state="success",
|
| 198 |
+
role=role,
|
| 199 |
+
content=messages[-1],
|
| 200 |
+
prompt_token=0,
|
| 201 |
+
completion_token=0,
|
| 202 |
+
)
|
| 203 |
+
except Exception as exception:
|
| 204 |
+
print("Exception:", exception)
|
| 205 |
+
return ChatCompletion(state="error", content=exception)
|
| 206 |
+
|
| 207 |
+
    def function_chat_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion with OpenAI function-calling support.

        Sends the scratchpad once; if the model requests a function call,
        runs the mapped local callable and issues a second request so the
        model can incorporate the tool output into its final answer.

        :param message: The message to use for completion. Mutated in place:
            assistant/function turns are appended as the exchange proceeds.
        :type message: List[dict]
        :param function_map: The function map to use for completion
            (function name -> local callable).
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema to use for completion.
        :type function_schema: List[Dict]
        :return: ChatCompletionWithHistory object; on failure ``state`` is
            "error" and ``content`` is the exception text.
        :rtype: ChatCompletionWithHistory
        """
        # One schema entry is expected per mapped callable.
        assert len(function_schema) == len(function_map)
        try:
            # response = openai.ChatCompletion.create(
            #     engine=self.get_model_name(),  # GPT-4
            #     messages=message,
            #     functions=function_schema,
            #     timeout=1000,
            # )
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
            )
            response_message = response.choices[0]["message"]

            if response_message.get("function_call"):
                # Model asked for a tool: resolve and invoke the local
                # callable with the JSON-decoded arguments.
                function_name = response_message["function_call"]["name"]
                fuction_to_call = function_map[function_name]
                function_args = json.loads(
                    response_message["function_call"]["arguments"]
                )
                function_response = fuction_to_call(**function_args)

                # Postprocess function response: plain strings carry no
                # cost/token accounting; AgentOutput does.
                if isinstance(function_response, str):
                    plugin_cost = 0
                    plugin_token = 0
                elif isinstance(function_response, AgentOutput):
                    plugin_cost = function_response.cost
                    plugin_token = function_response.token_usage
                    function_response = function_response.output
                else:
                    raise Exception(
                        "Invalid tool response type. Must be on of [AgentOutput, str]"
                    )

                # Record the assistant's function_call turn, then the tool
                # result, before asking the model for the final answer.
                message.append(dict(response_message))
                message.append(
                    {
                        "role": "function",
                        "name": function_name,
                        "content": function_response,
                    }
                )
                second_response = openai.ChatCompletion.create(
                    model=self.get_model_name(),
                    messages=message,
                )
                message.append(dict(second_response.choices[0].message))
                # Token counts sum both API round trips.
                return ChatCompletionWithHistory(
                    state="success",
                    role=second_response.choices[0].message["role"],
                    content=second_response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
                    + second_response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    )
                    + second_response.get("usage", {}).get("completion_tokens", 0),
                    message_scratchpad=message,
                    plugin_cost=plugin_cost,
                    plugin_token=plugin_token,
                )
            else:
                # No tool requested: return the first response directly.
                message.append(dict(response_message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=response.choices[0].message["role"],
                    content=response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    ),
                    message_scratchpad=message,
                )

        except Exception as exception:
            print("Exception:", exception)
            return ChatCompletionWithHistory(state="error", content=str(exception))
|
| 310 |
+
|
| 311 |
+
    def function_chat_stream_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Streaming chat completion with function-calling support.

        Generator: yields ``(type, ChatCompletionWithHistory)`` tuples where
        ``type`` is either "content" (plain text deltas) or "function_call"
        (argument-JSON deltas). The callables in ``function_map`` are not
        invoked here — only the model's streamed request is relayed; the
        commented-out block at the end sketches a possible second round.

        :param message: The message scratchpad to send.
        :type message: List[dict]
        :param function_map: Function name -> local callable (used only for
            the schema/map consistency assertion here).
        :type function_map: Dict[str, Callable]
        :param function_schema: Function schemas advertised to the model.
        :type function_schema: List[Dict]
        :raises Exception: re-raises any API/stream failure after logging.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.get_model_name(),
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
            )
            # First chunk fixes the role and tells us whether the model is
            # streaming plain content or a function call (content is None
            # when a function_call is being produced).
            tmp = next(response)
            role = tmp.choices[0].delta["role"]
            _type = (
                "function_call"
                if tmp.choices[0].delta["content"] is None
                else "content"
            )
            if _type == "function_call":
                # Emit the opening of the call JSON; subsequent deltas
                # stream the "arguments" value.
                name = tmp.choices[0].delta["function_call"]["name"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )
            for resp in response:
                # print(resp)
                content = resp.choices[0].delta.get(_type, "")
                if isinstance(content, dict):
                    # function_call deltas arrive as dicts; forward only
                    # the streamed arguments fragment.
                    content = content["arguments"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )

            # result = ''.join(messages)
            # if _type == "function_call":
            #     result = json.loads(result)
            #     function_name = result["name"]
            #     fuction_to_call = function_map[function_name]
            #     function_args = result["arguments"]
            #     function_response = fuction_to_call(**function_args)
            #
            #     # Postprocess function response
            #     if isinstance(function_response, AgentOutput):
            #         function_response = function_response.output
            #     message.append({"role": "function",
            #                     "name": function_name,
            #                     "content": function_response})
            #     second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema)
            #     message.append(dict(second_response.choices[0].message))

        except Exception as e:
            logger.error(f"Failed to get response {str(e)}", exc_info=True)
            raise e
|
src/infiagent/llm/client/openai.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from typing import Callable, List
|
| 5 |
+
|
| 6 |
+
import openai
|
| 7 |
+
|
| 8 |
+
from ..base_llm import BaseLLM
|
| 9 |
+
from ...schemas import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OpenAIGPTClient(BaseLLM, ABC):
    """
    Wrapper class for OpenAI GPT API collections.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The parameters for the model.
    :type params: OpenAIParamModel
    """
    # Model/deployment name passed on each API request.
    model_name: str
    # Sampling/decoding parameters; defaults come from OpenAIParamModel.
    params: OpenAIParamModel = OpenAIParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # NOTE(review): mutates module-global openai config; key comes from
        # the OPENAI_API_KEY environment variable (empty string if unset —
        # requests will then fail at call time, not here).
        openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    @classmethod
    async def create(cls, config_data):
        """Async factory: build a client from a config mapping."""
        return OpenAIGPTClient(**config_data)

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model_name

    def get_model_param(self) -> OpenAIParamModel:
        """Return the model parameter object."""
        return self.params
|
| 37 |
+
|
| 38 |
+
def completion(self, prompt: str, **kwargs) -> BaseCompletion:
|
| 39 |
+
"""
|
| 40 |
+
Completion method for OpenAI GPT API.
|
| 41 |
+
|
| 42 |
+
:param prompt: The prompt to use for completion.
|
| 43 |
+
:type prompt: str
|
| 44 |
+
:param kwargs: Additional keyword arguments.
|
| 45 |
+
:type kwargs: dict
|
| 46 |
+
:return: BaseCompletion object.
|
| 47 |
+
:rtype: BaseCompletion
|
| 48 |
+
|
| 49 |
+
"""
|
| 50 |
+
try:
|
| 51 |
+
#TODO any full parameters support
|
| 52 |
+
response = openai.ChatCompletion.create(
|
| 53 |
+
# n=self.params['n'],
|
| 54 |
+
engine=self.model_name,
|
| 55 |
+
messages=[{"role": "user", "content": prompt}],
|
| 56 |
+
temperature=self.params['temperature'],
|
| 57 |
+
max_tokens=self.params['max_tokens'],
|
| 58 |
+
top_p=self.params['top_p'],
|
| 59 |
+
# frequency_penalty=self.params.frequency_penalty,
|
| 60 |
+
# presence_penalty=self.params.presence_penalty,
|
| 61 |
+
**kwargs
|
| 62 |
+
)
|
| 63 |
+
return BaseCompletion(state="success",
|
| 64 |
+
content=response.choices[0].message["content"],
|
| 65 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 66 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 67 |
+
except Exception as exception:
|
| 68 |
+
print("Exception:", exception)
|
| 69 |
+
return BaseCompletion(state="error", content=exception)
|
| 70 |
+
|
| 71 |
+
async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
|
| 72 |
+
"""
|
| 73 |
+
Async Completion method for OpenAI GPT API.
|
| 74 |
+
|
| 75 |
+
:param prompt: The prompt to use for completion.
|
| 76 |
+
:type prompt: str
|
| 77 |
+
:param kwargs: Additional keyword arguments.
|
| 78 |
+
:type kwargs: dict
|
| 79 |
+
:return: BaseCompletion object.
|
| 80 |
+
:rtype: BaseCompletion
|
| 81 |
+
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
response = await openai.ChatCompletion.acreate(
|
| 85 |
+
model=self.model_name,
|
| 86 |
+
messages=[{"role": "user", "content": prompt}],
|
| 87 |
+
temperature=self.params['temperature'],
|
| 88 |
+
max_tokens=self.params['max_tokens'],
|
| 89 |
+
top_p=self.params['top_p'],
|
| 90 |
+
# frequency_penalty=self.params.frequency_penalty,
|
| 91 |
+
# presence_penalty=self.params.presence_penalty,
|
| 92 |
+
**kwargs
|
| 93 |
+
)
|
| 94 |
+
return BaseCompletion(state="success",
|
| 95 |
+
content=response.choices[0].message["content"],
|
| 96 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 97 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 98 |
+
except Exception as exception:
|
| 99 |
+
print("Exception:", exception)
|
| 100 |
+
return BaseCompletion(state="error", content=exception)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def chat_completion(self, message: List[dict]) -> ChatCompletion:
|
| 104 |
+
"""
|
| 105 |
+
Chat completion method for OpenAI GPT API.
|
| 106 |
+
|
| 107 |
+
:param message: The message to use for completion.
|
| 108 |
+
:type message: List[dict]
|
| 109 |
+
:return: ChatCompletion object.
|
| 110 |
+
:rtype: ChatCompletion
|
| 111 |
+
"""
|
| 112 |
+
try:
|
| 113 |
+
response = openai.ChatCompletion.create(
|
| 114 |
+
n=self.params.n,
|
| 115 |
+
model=self.model_name,
|
| 116 |
+
messages=message,
|
| 117 |
+
temperature=self.params.temperature,
|
| 118 |
+
max_tokens=self.params.max_tokens,
|
| 119 |
+
top_p=self.params.top_p,
|
| 120 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 121 |
+
presence_penalty=self.params.presence_penalty,
|
| 122 |
+
)
|
| 123 |
+
return ChatCompletion(state="success",
|
| 124 |
+
role=response.choices[0].message["role"],
|
| 125 |
+
content=response.choices[0].message["content"],
|
| 126 |
+
prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
|
| 127 |
+
completion_token=response.get("usage", {}).get("completion_tokens", 0))
|
| 128 |
+
except Exception as exception:
|
| 129 |
+
print("Exception:", exception)
|
| 130 |
+
return ChatCompletion(state="error", content=exception)
|
| 131 |
+
|
| 132 |
+
def stream_chat_completion(self, message: List[dict], **kwargs):
|
| 133 |
+
"""
|
| 134 |
+
Stream output chat completion for OpenAI GPT API.
|
| 135 |
+
|
| 136 |
+
:param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
|
| 137 |
+
:type message: List[dict]
|
| 138 |
+
:param kwargs: Additional keyword arguments.
|
| 139 |
+
:type kwargs: dict
|
| 140 |
+
:return: ChatCompletion object.
|
| 141 |
+
:rtype: ChatCompletion
|
| 142 |
+
"""
|
| 143 |
+
try:
|
| 144 |
+
response = openai.ChatCompletion.create(
|
| 145 |
+
n=self.params.n,
|
| 146 |
+
model=self.model_name,
|
| 147 |
+
messages=message,
|
| 148 |
+
temperature=self.params.temperature,
|
| 149 |
+
max_tokens=self.params.max_tokens,
|
| 150 |
+
top_p=self.params.top_p,
|
| 151 |
+
frequency_penalty=self.params.frequency_penalty,
|
| 152 |
+
presence_penalty=self.params.presence_penalty,
|
| 153 |
+
stream=True,
|
| 154 |
+
**kwargs
|
| 155 |
+
)
|
| 156 |
+
role = next(response).choices[0].delta["role"]
|
| 157 |
+
messages = []
|
| 158 |
+
## TODO: Calculate prompt_token and for stream mode
|
| 159 |
+
for resp in response:
|
| 160 |
+
messages.append(resp.choices[0].delta.get("content", ""))
|
| 161 |
+
yield ChatCompletion(state="success",
|
| 162 |
+
role=role,
|
| 163 |
+
content=messages[-1],
|
| 164 |
+
prompt_token=0,
|
| 165 |
+
completion_token=0)
|
| 166 |
+
except Exception as exception:
|
| 167 |
+
print("Exception:", exception)
|
| 168 |
+
return ChatCompletion(state="error", content=exception)
|
| 169 |
+
|
| 170 |
+
def function_chat_completion(self, message: List[dict],
                             function_map: Dict[str, Callable],
                             function_schema: List[Dict]) -> ChatCompletionWithHistory:
    """
    Function-calling chat completion for the OpenAI GPT API.

    Requests a completion; if the model asks for a function call, invokes the
    mapped callable, appends the call and its result to ``message`` (mutated
    in place — it doubles as the returned ``message_scratchpad``), then asks
    the model for a follow-up reply.

    :param message: The message to use for completion.
    :type message: List[dict]
    :param function_map: Mapping from function name to the callable to invoke.
    :type function_map: Dict[str, Callable]
    :param function_schema: The function schema to use for completion.
    :type function_schema: List[Dict]
    :return: ChatCompletionWithHistory object.
    :rtype: ChatCompletionWithHistory
    """
    assert len(function_schema) == len(function_map)
    try:
        response = openai.ChatCompletion.create(
            n=self.params.n,
            model=self.model_name,
            messages=message,
            functions=function_schema,
            temperature=self.params.temperature,
            max_tokens=self.params.max_tokens,
            top_p=self.params.top_p,
            frequency_penalty=self.params.frequency_penalty,
            presence_penalty=self.params.presence_penalty,
        )
        response_message = response.choices[0]["message"]

        if response_message.get("function_call"):
            # The model requested a tool call: execute it locally.
            function_name = response_message["function_call"]["name"]
            function_to_call = function_map[function_name]
            function_args = json.loads(response_message["function_call"]["arguments"])
            function_response = function_to_call(**function_args)

            # Postprocess function response: plain strings carry no
            # cost/token accounting, AgentOutput does.
            if isinstance(function_response, str):
                plugin_cost = 0
                plugin_token = 0
            elif isinstance(function_response, AgentOutput):
                plugin_cost = function_response.cost
                plugin_token = function_response.token_usage
                function_response = function_response.output
            else:
                raise Exception("Invalid tool response type. Must be one of [AgentOutput, str]")

            # Record the call and its result, then ask for a follow-up reply.
            message.append(dict(response_message))
            message.append({"role": "function",
                            "name": function_name,
                            "content": function_response})
            second_response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=message,
            )
            message.append(dict(second_response.choices[0].message))
            return ChatCompletionWithHistory(state="success",
                                             role=second_response.choices[0].message["role"],
                                             content=second_response.choices[0].message["content"],
                                             prompt_token=response.get("usage", {}).get("prompt_tokens", 0) +
                                             second_response.get("usage", {}).get("prompt_tokens", 0),
                                             completion_token=response.get("usage", {}).get("completion_tokens", 0) +
                                             second_response.get("usage", {}).get("completion_tokens", 0),
                                             message_scratchpad=message,
                                             plugin_cost=plugin_cost,
                                             plugin_token=plugin_token,
                                             )
        else:
            # No function call: record the reply and return it directly.
            message.append(dict(response_message))
            return ChatCompletionWithHistory(state="success",
                                             role=response.choices[0].message["role"],
                                             content=response.choices[0].message["content"],
                                             prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                                             completion_token=response.get("usage", {}).get("completion_tokens", 0),
                                             message_scratchpad=message)

    except Exception as exception:
        print("Exception:", exception)
        return ChatCompletionWithHistory(state="error", content=str(exception))
|
| 249 |
+
|
| 250 |
+
def function_chat_stream_completion(self, message: List[dict],
                                    function_map: Dict[str, Callable],
                                    function_schema: List[Dict]) -> ChatCompletionWithHistory:
    """
    Streamed function-calling chat completion for the OpenAI GPT API.

    Yields ``(chunk_type, ChatCompletionWithHistory)`` tuples where
    ``chunk_type`` is "content" for plain text deltas or "function_call" for
    streamed function-call argument deltas (preceded by a synthetic chunk
    that opens a JSON object with the function name).

    Errors are printed and re-raised to the caller. The original had an
    unreachable ``print``/``return`` *after* ``raise``; that dead code is
    removed and the print now runs before propagating.

    :param message: The message (scratchpad) to use for completion.
    :param function_map: Mapping from function name to the callable to invoke.
    :param function_schema: The function schema to use for completion.
    :return: Generator of (chunk_type, ChatCompletionWithHistory) tuples.
    """
    assert len(function_schema) == len(function_map)
    try:
        response = openai.ChatCompletion.create(
            n=self.params.n,
            model=self.model_name,
            messages=message,
            functions=function_schema,
            temperature=self.params.temperature,
            max_tokens=self.params.max_tokens,
            top_p=self.params.top_p,
            frequency_penalty=self.params.frequency_penalty,
            presence_penalty=self.params.presence_penalty,
            stream=True
        )
        tmp = next(response)
        role = tmp.choices[0].delta["role"]
        # A function-call stream carries no "content" in its first delta.
        _type = "function_call" if tmp.choices[0].delta["content"] is None else "content"
        if _type == "function_call":
            name = tmp.choices[0].delta['function_call']['name']
            # Emit the opening of a JSON object so consumers can concatenate
            # the streamed argument chunks after it.
            yield _type, ChatCompletionWithHistory(state="success", role=role,
                                                   content="{" + f'"name":"{name}", "arguments":',
                                                   message_scratchpad=message)
        for resp in response:
            content = resp.choices[0].delta.get(_type, "")
            if isinstance(content, dict):
                content = content['arguments']
            yield _type, ChatCompletionWithHistory(state="success",
                                                   role=role,
                                                   content=content,
                                                   message_scratchpad=message)

    except Exception as exception:
        print("Exception:", exception)
        raise
|
src/infiagent/llm/client/opt.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from typing import Callable, List
|
| 6 |
+
|
| 7 |
+
import openai
|
| 8 |
+
from tenacity import ( # for exponential backoff
|
| 9 |
+
before_sleep_log,
|
| 10 |
+
retry,
|
| 11 |
+
stop_after_attempt,
|
| 12 |
+
wait_random_exponential,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from ..base_llm import BaseLLM
|
| 16 |
+
from ...schemas import *
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
MAX_PROMPT_LENGTH = 7000
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
def chatcompletion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with exponential-backoff retries.

    Retries up to 10 attempts with random exponential backoff (1-10s),
    logging a warning before each sleep and re-raising the final error.
    """
    return openai.ChatCompletion.create(**kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
async def async_chatcompletion_with_backoff(**kwargs):
    """Async variant of :func:`chatcompletion_with_backoff`.

    Awaits ``openai.ChatCompletion.acreate`` under the same retry policy.
    """
    # The original wrapped the call in a nested `_internal_coroutine` and
    # awaited it, which adds nothing: awaiting acreate() directly is
    # equivalent (tenacity's retry decorator handles coroutine functions).
    return await openai.ChatCompletion.acreate(**kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class OptOpenAIClient(BaseLLM, ABC):
    """
    Wrapper class for OpenAI GPT API collections.

    Talks to an OpenAI-compatible server at ``http://localhost:8000/v1``
    (set in ``__init__``), e.g. a locally hosted OPT model.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The parameters for the model.
    :type params: OptParamModel
    """

    model_name: str
    params: OptParamModel = OptParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # Point the openai SDK at the local OpenAI-compatible endpoint;
        # the local server ignores the API key, hence the placeholder.
        openai.api_key = "EMPTY"
        openai.api_base = "http://localhost:8000/v1"

    @classmethod
    async def create(cls, config_data):
        """Async factory kept for interface parity with the other LLM clients."""
        return OptOpenAIClient(**config_data)

    def get_model_name(self) -> str:
        return self.model_name

    def get_model_param(self) -> OptParamModel:
        return self.params

    def completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Completion method for OpenAI GPT API.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion

        """
        # Keep only the trailing MAX_PROMPT_LENGTH characters to bound prompt size.
        response = chatcompletion_with_backoff(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                              completion_token=response.get("usage", {}).get("completion_tokens", 0))

    async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Completion method for OpenAI GPT API.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion

        """
        response = await async_chatcompletion_with_backoff(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                              completion_token=response.get("usage", {}).get("completion_tokens", 0))

    def chat_completion(self, message: List[dict]) -> ChatCompletion:
        """
        Chat completion method for OpenAI GPT API.

        :param message: The message to use for completion.
        :type message: List[dict]
        :return: ChatCompletion object.
        :rtype: ChatCompletion
        """
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
            )
            return ChatCompletion(
                state="success",
                role=response.choices[0].message["role"],
                content=response.choices[0].message["content"],
                prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                completion_token=response.get("usage", {}).get("completion_tokens", 0),
            )
        except Exception as exception:
            print("Exception:", exception)
            # NOTE(review): content is the exception *object* here, while
            # other error paths in this module use str(exception) — confirm
            # callers tolerate a non-string content.
            return ChatCompletion(state="error", content=exception)

    def stream_chat_completion(self, message: List[dict], **kwargs):
        """
        Stream output chat completion for OpenAI GPT API.

        :param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
        :type message: List[dict]
        :param kwargs: Additional keyword arguments.
        :type kwargs: dict
        :return: ChatCompletion object.
        :rtype: ChatCompletion
        """
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
                **kwargs
            )
            # First streamed chunk only carries the role.
            role = next(response).choices[0].delta["role"]
            messages = []
            ## TODO: Calculate prompt_token and for stream mode
            for resp in response:
                messages.append(resp.choices[0].delta.get("content", ""))
                yield ChatCompletion(
                    state="success",
                    role=role,
                    content=messages[-1],
                    prompt_token=0,
                    completion_token=0,
                )
        except Exception as exception:
            print("Exception:", exception)
            # NOTE(review): a `return value` inside a generator only surfaces
            # via StopIteration.value — the consumer never sees this error
            # object; consider yielding it instead.
            return ChatCompletion(state="error", content=exception)

    def function_chat_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion method for OpenAI GPT API.

        :param message: The message to use for completion.
        :type message: List[dict]
        :param function_map: The function map to use for completion.
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema to use for completion.
        :type function_schema: List[Dict]
        :return: ChatCompletionWithHistory object.
        :rtype: ChatCompletionWithHistory
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
            )
            response_message = response.choices[0]["message"]

            if response_message.get("function_call"):
                # The model requested a tool call: execute it locally.
                function_name = response_message["function_call"]["name"]
                fuction_to_call = function_map[function_name]
                function_args = json.loads(
                    response_message["function_call"]["arguments"]
                )
                function_response = fuction_to_call(**function_args)

                # Postprocess function response
                if isinstance(function_response, str):
                    plugin_cost = 0
                    plugin_token = 0
                elif isinstance(function_response, AgentOutput):
                    plugin_cost = function_response.cost
                    plugin_token = function_response.token_usage
                    function_response = function_response.output
                else:
                    raise Exception(
                        "Invalid tool response type. Must be on of [AgentOutput, str]"
                    )

                # Record the call and its result in the shared scratchpad
                # (mutates the caller's list), then ask for a follow-up reply.
                message.append(dict(response_message))
                message.append(
                    {
                        "role": "function",
                        "name": function_name,
                        "content": function_response,
                    }
                )
                second_response = openai.ChatCompletion.create(
                    model=self.get_model_name(),
                    messages=message,
                )
                message.append(dict(second_response.choices[0].message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=second_response.choices[0].message["role"],
                    content=second_response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
                    + second_response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    )
                    + second_response.get("usage", {}).get("completion_tokens", 0),
                    message_scratchpad=message,
                    plugin_cost=plugin_cost,
                    plugin_token=plugin_token,
                )
            else:
                message.append(dict(response_message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=response.choices[0].message["role"],
                    content=response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get(
                        "completion_tokens", 0
                    ),
                    message_scratchpad=message,
                )

        except Exception as exception:
            print("Exception:", exception)
            return ChatCompletionWithHistory(state="error", content=str(exception))

    def function_chat_stream_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Streamed function-calling chat completion.

        Yields ``(chunk_type, ChatCompletionWithHistory)`` tuples; chunk_type
        is "content" for plain text deltas or "function_call" for streamed
        function-call argument deltas.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.get_model_name(),
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
            )
            tmp = next(response)
            role = tmp.choices[0].delta["role"]
            # A function-call stream carries no "content" in its first delta.
            _type = (
                "function_call"
                if tmp.choices[0].delta["content"] is None
                else "content"
            )
            if _type == "function_call":
                name = tmp.choices[0].delta["function_call"]["name"]
                # Emit the opening of a JSON object so consumers can append
                # the streamed argument chunks after it.
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )
            for resp in response:
                content = resp.choices[0].delta.get(_type, "")
                if isinstance(content, dict):
                    content = content["arguments"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )

        except Exception as e:
            logger.error(f"Failed to get response {str(e)}", exc_info=True)
            raise e
|
src/infiagent/prompt/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .prompt_template import *
|
| 2 |
+
from .simple_react_prompt import SimpleReactPrompt
|
| 3 |
+
from .zero_shot_react_prompt import ZeroShotReactPrompt
|
src/infiagent/prompt/prompt_template.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt schema definition."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from string import Formatter
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Extra, root_validator
|
| 7 |
+
|
| 8 |
+
from ..exceptions.exceptions import InputErrorException
|
| 9 |
+
from ..schemas import AgentAction, AgentObservation, BaseAgentResponse
|
| 10 |
+
|
| 11 |
+
OBSERVATION_KEY = "Observation"
|
| 12 |
+
THOUGHT_KEY = "Thought"
|
| 13 |
+
FINAL_ANSWER_KEY = "FinalAnswer"
|
| 14 |
+
|
| 15 |
+
DEFAULT_OBSERVATION = "Observation:"
|
| 16 |
+
DEFAULT_THOUGHT = "Thought:"
|
| 17 |
+
DEFAULT_FINAL_ANSWER = "Final Answer:"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PromptTemplate(BaseModel, ABC):
    """
    Base class for agent prompt templates.

    Subclasses declare the template string, its input variables, and the
    keyword markers (Observation/Thought/Final Answer) used when building
    the agent scratchpad.
    """

    # NOTE(review): underscore-prefixed names are *private attributes* in
    # pydantic v1, not model fields — confirm subclass class-level
    # assignments actually reach instances as intended.
    _input_variables: List[str]
    _template: str
    _keywords: Dict[str, str]
    _name: str
    _validate_template: bool
    _skip_on_failure: bool

    class Config:
        extra = Extra.forbid

    @property
    def input_variables(self) -> List[str]:
        return self._input_variables

    @property
    def template(self) -> str:
        return self._template

    @property
    def keywords(self) -> Dict[str, str]:
        return self._keywords

    @property
    def name(self) -> str:
        return self._name

    def format(self, **kwargs):
        """Fill the template; raise InputErrorException if any declared variable is missing."""
        if not set(self._input_variables).issubset(kwargs.keys()):
            missing_keys = set(self._input_variables) - kwargs.keys()
            raise InputErrorException(f"Missing keys in prompt template: {', '.join(missing_keys)}")

        # Drop extra kwargs so str.format only sees declared variables.
        filtered_kwargs = {key: kwargs[key] for key in self._input_variables if key in kwargs}

        return self._template.format(**filtered_kwargs)

    def construct_scratchpad(self, intermediate_steps: List[BaseAgentResponse]) -> str:
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""

        for agent_response in intermediate_steps:
            if isinstance(agent_response, AgentAction):
                # for agent action, use thought
                thoughts += agent_response.raw_output
            elif isinstance(agent_response, AgentObservation):
                # for agent observation use observation
                thoughts += f"\n{self.keywords.get(OBSERVATION_KEY, DEFAULT_OBSERVATION)}\n" \
                            f"{agent_response.formatted_output}\n\n" \
                            f"{self.keywords.get(THOUGHT_KEY, DEFAULT_THOUGHT)}\n"

        return thoughts

    @classmethod
    @root_validator(skip_on_failure=True)
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that template and input variables are consistent."""
        # NOTE(review): this reads values["validate_template"] /
        # values["input_variables"], but the class only declares the
        # underscore-prefixed private attrs above — verify the validator
        # actually receives these keys (pydantic v1 excludes private attrs
        # from `values`).
        if values["validate_template"]:
            try:
                dummy_input = {var: "" for var in values["input_variables"]}
                Formatter().format(values["template"], **dummy_input)
            except KeyError as e:
                raise InputErrorException("Invalid prompt schema; check for mismatched or missing input parameters. ")\
                    from e
        return values
|
src/infiagent/prompt/simple_react_prompt.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..prompt import FINAL_ANSWER_KEY, OBSERVATION_KEY, THOUGHT_KEY, PromptTemplate
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SimpleReactPrompt(PromptTemplate):
    """Minimal ReAct prompt whose section markers are the special tokens
    [SEP]/[EOS]Observation:/[END] declared in ``_keywords``."""
    _input_variables = ["instruction", "agent_scratchpad"]
    _template = "{instruction} \n{agent_scratchpad}"
    _keywords = {
        OBSERVATION_KEY: "[EOS]Observation:",
        THOUGHT_KEY: "[SEP]",
        FINAL_ANSWER_KEY: "[END]"
    }
    _name = 'SimpleReactPrompt'
    _validate_template = True
    _skip_on_failure = True

    def __init__(self, **data):
        super().__init__(**data)
|
src/infiagent/prompt/zero_shot_react_prompt.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..prompt import PromptTemplate, OBSERVATION_KEY, THOUGHT_KEY, FINAL_ANSWER_KEY, DEFAULT_OBSERVATION, \
|
| 2 |
+
DEFAULT_THOUGHT, DEFAULT_FINAL_ANSWER
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ZeroShotReactPrompt(PromptTemplate):
    """Standard zero-shot ReAct prompt: Thought / Action / Action Input /
    Observation loop ending in a Final Answer, with the default keyword
    markers from ``prompt_template``."""
    _input_variables = ["instruction", "agent_scratchpad", "tool_names", "tool_description"]
    _template = (
        "Answer the following questions as best you can."
        "You have access to the following tools:\n"
        "{tool_description}.\n"
        "Use the following format:\n\n"
        "Question: the input question you must answer\n"
        "Thought: you should always think about what to do\n\n"
        "Action: the action to take, should be one of [{tool_names}]\n\n"
        "Action Input:\n```python\n[the input to the action]\n```\n"
        "Observation: the result of the action\n\n"
        "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
        "Thought: I now know the final answer\n"
        "Final Answer: the final answer to the original input question\n"
        "If you have any files outputted write them to \"./\"\n"
        "Do not use things like plot.show() as it will not work instead write them out \"./\"\n"
        "Begin!\n\n"
        "Question: {instruction}\nThought:\n"
        "{agent_scratchpad}\n"
    )
    _keywords = {
        OBSERVATION_KEY: DEFAULT_OBSERVATION,
        THOUGHT_KEY: DEFAULT_THOUGHT,
        FINAL_ANSWER_KEY: DEFAULT_FINAL_ANSWER
    }
    _name = 'ZeroShotReactPrompt'
    _validate_template = True
    _skip_on_failure = True

    def __init__(self, **data):
        super().__init__(**data)
|
src/infiagent/schemas/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_models import *
|
| 2 |
+
from .complete_models import *
|
| 3 |
+
from .sandbox_models import *
|
| 4 |
+
from .agent_models import *
|
| 5 |
+
from .llm_models import *
|
src/infiagent/schemas/agent_models.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import List, NamedTuple, Optional, Union
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from ..schemas.sandbox_models import *
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class BaseAgentResponse:
    """Base Agent step result, contains formatted output string."""
    formatted_output: str  # parsed/cleaned text shown downstream
    raw_output: str  # unmodified LLM output this step was parsed from
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class AgentAction(BaseAgentResponse):
    """
    Agent's action to take.
    """
    tool: str  # name of the tool the agent decided to invoke
    tool_input: Union[str, dict]  # raw input passed to the tool (string or structured)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class AgentObservation(BaseAgentResponse):
    """
    Observation produced after a tool was executed (result fed back to the agent).
    """
    tool: str  # name of the tool that produced this observation
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class AgentFinish(BaseAgentResponse):
    """Agent's return value when finishing execution."""
    # No extra fields: the final answer lives in formatted_output / raw_output.
    pass
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class AgentType(Enum):
    """
    Enumerated type for agent types.
    """
    openai = "openai"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"
    openai_memory = "openai_memory"

    @staticmethod
    def get_agent_class(_type: AgentType):
        """
        Resolve the concrete agent class for the given agent type.

        :param _type: agent type
        :return: agent class
        :raises ValueError: for any type without an implementation (only react is supported)
        """
        if _type != AgentType.react:
            raise ValueError(f"Unknown agent type: {_type}")
        # Imported lazily so loading this module does not pull in the agent package.
        from ..agent.react import ReactAgent
        return ReactAgent
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class AgentOutput(BaseModel):
    """
    Pydantic model for agent output.
    """
    output: str  # final answer text
    cost: float  # accumulated monetary cost of the run — units unspecified here; TODO confirm
    token_usage: int  # total tokens consumed by the run
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
class AgentRequest:
    """Inbound request bundle for one agent turn."""
    sandbox_id: Optional[str] = None  # existing sandbox to reuse; None to create a new one
    messages: List[Message] = field(default_factory=list)  # conversation history
    input_files: List[MediaFile] = field(default_factory=list)  # files to make available in the sandbox
    sandbox_status: Optional[SandboxStatus] = None  # last known sandbox state, if any
    is_cn: bool = False  # presumably toggles Chinese-language behavior — TODO confirm with callers
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
class AgentResponse:
    """Outbound result of one agent turn."""
    output_text: str  # formatted text to show the user
    raw_output_text: str  # unprocessed model output
    output_files: List[MediaFile] = field(default_factory=list)  # files generated during the turn
    sandbox_id: Optional[str] = None  # sandbox used for this turn
    sandbox_status: Optional[SandboxStatus] = None  # sandbox state after the turn
    turn_level_prompt: Optional[List[str]] = None  # per-iteration prompts, for tracing/eval
    turn_level_response: Optional[List[str]] = None  # per-iteration raw responses, for tracing/eval
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class RoleType(Enum):
    """Conversation participant roles."""
    User = 0
    System = 1
    Agent = 2

    @classmethod
    def _missing_(cls, value):
        """Allow lookup by case-insensitive member name (e.g. RoleType("user"))."""
        if isinstance(value, str):
            wanted = value.lower()
            for candidate in cls:
                if candidate.name.lower() == wanted:
                    return candidate
        return super()._missing_(value)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclass
class Message(abc.ABC):
    """A single chat message with its role and (optionally raw) content."""
    role: RoleType
    content: str
    raw_content: str = ""

    @staticmethod
    def parse_from_dict(data):
        """Build a Message from a plain dict, coercing the role value."""
        data['role'] = RoleType(data['role'])
        # Legacy payloads may predate the raw_content field.
        data.setdefault('raw_content', "")
        return Message(**data)

    def to_dict(self):
        """Serialize to a plain dict with the role as its primitive value."""
        if isinstance(self.role, RoleType):
            role_value = self.role.value
        else:
            role_value = self.role
        return {
            "role": role_value,
            "content": self.content,
            "raw_content": self.raw_content
        }
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
class MediaFile:
    """A file attached to a request/response, possibly present in several storage locations."""
    file_name: Optional[str] = None
    file_content: Optional[bytes] = None
    tos_path: Optional[str] = None
    sandbox_path: Optional[str] = None

    # NOTE: shadows the instance __dict__ attribute with a callable; existing callers
    # invoke media_file.__dict__() to get a mapping with "" substituted for None.
    def __dict__(self):
        serialized = {}
        for key in ('file_name', 'file_content', 'tos_path', 'sandbox_path'):
            value = getattr(self, key)
            serialized[key] = "" if value is None else value
        return serialized
|
src/infiagent/schemas/base_models.py
ADDED
|
File without changes
|
src/infiagent/schemas/complete_models.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from time import time
|
| 4 |
+
from typing import Any, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
from ..schemas.agent_models import Message
|
| 9 |
+
from ..utils.file_utils import get_file_name_and_path
|
| 10 |
+
|
| 11 |
+
# Definitions for inputs and outputs schema for /complete api
|
| 12 |
+
|
| 13 |
+
# Sampling defaults applied when the request leaves them unset.
DEFAULT_TOP_P = 0.7
DEFAULT_TEMPERATURE = 1.0
DEFAULT_STREAM = False

# Lifecycle states reported in Delta.status.
FINISH_STATUS = "FINISH"
FAILED_STATUS = "FAILED"
PROCESSING_STATUS = "PROCESSING"
# Role string expected by the front end.
ASSISTANT = "assistant"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Main Input Model
class ChatCompleteRequest(BaseModel):
    """Inbound payload for the /complete API."""
    chat_id: str  # unique chat id for given chat
    code_interpreter: Optional[dict] = {}  # may carry "tos_key" pointing at an uploaded file
    messages: List[dict] = []  # chat message
    model: str = "AZURE_OPEN_AI"  # model name map to LLM conf
    user: str  # caller identity; copied into Delta.creator/updater
    max_tokens: Optional[int] = None
    message_conf: Optional[dict] = {}
    n: Optional[int] = None
    plugins: Optional[List[str]] = None
    seed_conf: Optional[dict] = {}
    stream: Optional[bool] = None  # None -> DEFAULT_STREAM
    temperature: Optional[float] = None  # None -> DEFAULT_TEMPERATURE
    top_p: Optional[float] = None  # None -> DEFAULT_TOP_P
    top_k: Optional[int] = None
    webgpt: Optional[Dict[str, Any]] = None
    webgpt_network: Optional[bool] = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class MessageConf(BaseModel):
    """Per-message generation configuration echoed back in each Delta."""
    top_p: float = DEFAULT_TOP_P
    temperature: float = DEFAULT_TEMPERATURE
    top_k: Optional[int] = None
    time_cost: int  # elapsed time for the message; always 0 where constructed in this module
    code_interpreter: dict  # e.g. {"tos_key": <file name>} when an input file is attached
    gpt_engine_conf: dict
    stream: bool
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Delta(BaseModel):
    """One message delta in a chat completion choice."""
    role: str  # e.g. ASSISTANT
    content: str
    sid: str
    status: str  # one of FINISH_STATUS / FAILED_STATUS / PROCESSING_STATUS
    end_turn: bool
    parent_id: str
    children_ids: Optional[Union[List[str], None]]
    err_msg: str
    creator: str
    updater: str
    ctime: str  # creation timestamp string
    utime: str  # last-update timestamp string
    message_conf: MessageConf

    def json(self, *args, **kwargs):
        """Serialize like BaseModel.json, but render UTC offsets as 'Z'."""
        serialized_data = super().json(*args, **kwargs)
        return serialized_data.replace("+00:00", "Z")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Choice(BaseModel):
    """A single completion choice wrapping one Delta."""
    index: int  # position within ChatCompleteResponse.choices
    delta: Delta
    finish_reason: str  # "stop" for completed turns in this module
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ChatCompleteResponse(BaseModel):
    """Top-level response payload for the /complete API."""
    id: str  # mirrors the request's chat_id
    created: int  # Unix timestamp of response creation
    choices: List[Choice]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def chat_request_to_message_conf(chat_request: ChatCompleteRequest) -> MessageConf:
    """Derive a MessageConf from the request, substituting defaults for unset sampling knobs."""
    input_files = {}

    code_interpreter = chat_request.code_interpreter
    if code_interpreter and "tos_key" in code_interpreter:
        # Only the bare file name is kept; the tos path component is not needed here.
        file_name, _tos_path = get_file_name_and_path(code_interpreter["tos_key"])
        input_files = {"tos_key": file_name}

    top_p = DEFAULT_TOP_P if chat_request.top_p is None else chat_request.top_p
    temperature = DEFAULT_TEMPERATURE if chat_request.temperature is None else chat_request.temperature
    stream = DEFAULT_STREAM if chat_request.stream is None else chat_request.stream

    return MessageConf(
        top_p=top_p,
        temperature=temperature,
        code_interpreter=input_files,
        time_cost=0,
        gpt_engine_conf={},
        stream=stream
    )
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def chat_request_to_deltas(chat_request: ChatCompleteRequest) -> List[Delta]:
    """Convert each message in the request into a finished assistant Delta.

    All deltas share one MessageConf derived from the request; creator/updater
    are taken from the request's user field.
    """
    deltas = []
    message_conf = chat_request_to_message_conf(chat_request)

    for message in chat_request.messages:
        delta = Delta(
            role=ASSISTANT,
            content=message["content"],
            sid="",
            # Use the module constant instead of a hard-coded literal, consistent
            # with update_chat_response_with_message (same value: "FINISH").
            status=FINISH_STATUS,
            end_turn=False,
            parent_id="",
            children_ids=None,
            err_msg="",
            creator=chat_request.user,
            updater=chat_request.user,
            ctime=current_utc_time_as_str(),
            utime=current_utc_time_as_str(),
            message_conf=message_conf
        )
        deltas.append(delta)

    return deltas
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def chat_request_to_choices(chat_request: ChatCompleteRequest) -> List[Choice]:
    """Wrap each Delta derived from the request in a Choice with finish_reason 'stop'."""
    return [
        Choice(index=position, delta=one_delta, finish_reason="stop")
        for position, one_delta in enumerate(chat_request_to_deltas(chat_request))
    ]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def chat_request_to_response(chat_request: ChatCompleteRequest) -> ChatCompleteResponse:
    """Build an initial ChatCompleteResponse that echoes the request's messages as choices."""
    return ChatCompleteResponse(
        id=chat_request.chat_id,
        created=int(time()),
        choices=chat_request_to_choices(chat_request)
    )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def update_chat_response_with_message(chat_response: ChatCompleteResponse,
                                      message: Message,
                                      status: Union[str, None] = None) -> ChatCompleteResponse:
    """Replace the response's choices with a single new Choice carrying *message*.

    Metadata (sid, parent_id, timestamps, message_conf, ...) is inherited from the
    last existing Delta when present, otherwise falls back to empty/default values.

    :param chat_response: response to update in place (also returned)
    :param message: message whose content becomes the new delta's content
    :param status: delta status; defaults to FINISH_STATUS when None
    :return: the same chat_response, now containing exactly one Choice
    """
    # Get the last Delta (if exists)
    last_delta = chat_response.choices[-1].delta if chat_response.choices else None
    updated_delta = Delta(
        role=ASSISTANT,  # map with front end
        content=message.content,
        sid=last_delta.sid if last_delta else "",
        status=status if status is not None else FINISH_STATUS,
        end_turn=False,
        parent_id=last_delta.parent_id if last_delta else "",
        children_ids=last_delta.children_ids if last_delta else None,
        err_msg="",
        creator=last_delta.creator if last_delta else None,
        updater=last_delta.updater if last_delta else None,
        ctime=last_delta.ctime if last_delta else current_utc_time_as_str(),
        utime=current_utc_time_as_str(),
        # NOTE(review): these truthiness checks treat 0/0.0/{} as "unset" and fall back
        # to defaults; e.g. an explicit temperature of 0.0 is replaced — confirm intended.
        message_conf=MessageConf(
            top_p=last_delta.message_conf.top_p if last_delta and last_delta.message_conf.top_p else DEFAULT_TOP_P,
            temperature=last_delta.message_conf.temperature if last_delta and last_delta.message_conf.temperature else
            DEFAULT_TEMPERATURE,
            code_interpreter=last_delta.message_conf.code_interpreter
            if last_delta and last_delta.message_conf.code_interpreter else {},
            time_cost=0,
            gpt_engine_conf={},
            stream=last_delta.message_conf.stream if last_delta and last_delta.message_conf.stream is not None else
            False
        )
    )

    updated_choice = Choice(
        index=0,  # Since it's the only choice in the list
        delta=updated_delta,
        finish_reason="stop"
    )

    # Update the ChatCompleteResponse to contain only the new Choice
    chat_response.choices = [updated_choice]
    return chat_response
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def current_utc_time_as_str() -> str:
    """Return the current UTC time formatted as 'YYYY-MM-DDTHH:MM:SSZ'."""
    # datetime.utcnow() is naive and deprecated since Python 3.12; use an
    # aware UTC datetime instead — the formatted output is identical.
    from datetime import timezone
    return datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def create_empty_response():
    """Return a placeholder ChatCompleteResponse with one empty choice.

    Used as a neutral starting value before real content/metadata is filled in.
    """
    # Dummy instance for Delta
    delta = Delta(
        role=ASSISTANT,
        content="",
        sid="",
        status="",
        end_turn=False,
        parent_id="",
        children_ids=None,
        err_msg="",
        creator="",
        updater="",
        ctime="",
        utime="",
        message_conf=MessageConf(
            top_p=0.0,
            temperature=0,
            time_cost=0,
            code_interpreter={},
            gpt_engine_conf={},
            stream=False
        )
    )

    # Dummy instance for Choice
    choice = Choice(
        index=0,
        delta=delta,
        finish_reason=""
    )

    # Dummy instance for ChatCompleteResponse
    response = ChatCompleteResponse(
        id="",
        created=0,
        choices=[choice]
    )
    return response
|
| 236 |
+
|
src/infiagent/schemas/llm_models.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Dict, List, NamedTuple, Union
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import torch
|
| 11 |
+
except ImportError:
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseCompletion(BaseModel):
    """Common fields shared by all LLM completion results."""
    state: str  # "success" or "error"
    content: str
    prompt_token: int = 0
    completion_token: int = 0

    def to_dict(self):
        """Serialize to a plain dict."""
        return {
            "state": self.state,
            "content": self.content,
            "prompt_token": self.prompt_token,
            "completion_token": self.completion_token,
        }
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ChatCompletion(BaseCompletion):
    """A chat-style completion that additionally carries the speaker role."""
    role: str = "assistant"  # "system" or "user" or "assistant"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ChatCompletionWithHistory(ChatCompletion):
    """Used for function call API"""
    message_scratchpad: List[Dict] = []  # accumulated function-call message history
    plugin_cost: float = 0.0
    plugin_token: float = 0.0
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BaseParamModel(BaseModel):
    """Parameter model compared by field values rather than identity."""
    # NOTE(review): defining __eq__ without __hash__ makes instances unhashable
    # under default Python rules — confirm no caller uses these as dict keys.
    def __eq__(self, other):
        return self.dict() == other.dict()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class OpenAIParamModel(BaseModel):
    """
    OpenAI API parameters
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1  # number of completions to request
    stop: list = []  # stop sequences (pydantic copies this default per instance)
|
| 58 |
+
|
| 59 |
+
class AzureOpenAIParamModel(BaseModel):
    """
    AzureOpenAI API parameters
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1  # number of completions to request
    stop: list = []  # stop sequences
|
| 70 |
+
|
| 71 |
+
class LlamaParamModel(BaseModel):
    """
    Llama API parameters
    """
    max_tokens: int = 4096  # larger default context than the OpenAI models above
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1
    stop: list = []
|
| 82 |
+
|
| 83 |
+
class OptParamModel(BaseModel):
    """
    OPT API parameters
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stop: list = []
|
src/infiagent/schemas/sandbox_models.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import Any, List, Optional
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
class SandboxStatus(Enum):
    """
    Enumerated type for sandbox execution statuses.
    """
    success = "success"
    failed = "failed"
    timeout = "timeout"
|
| 12 |
+
|
| 13 |
+
class CodeOutput(BaseModel):
    """One piece of output produced by code executed in the sandbox."""
    type: str  # output kind — exact values defined by the sandbox server; TODO confirm
    content: str
|
| 16 |
+
|
| 17 |
+
class ReturnedFile(BaseModel):
    """A file reported back by the sandbox (created or deleted during a run)."""
    download_link: str
    name: str
    path: str  # path inside the sandbox
|
| 21 |
+
|
| 22 |
+
class CodeRunResult(BaseModel):
    """Aggregate result of one sandbox code run."""
    code_output_result: List[CodeOutput]
    deleted_files: List[ReturnedFile]
    new_generated_files: List[ReturnedFile]
|
| 26 |
+
|
| 27 |
+
class CodeRunData(BaseModel):
    """Data envelope for a code run; is_partial marks streamed/incomplete results."""
    is_partial: bool
    result: CodeRunResult
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class RunCodeOutput(BaseModel):
    """Sandbox server response for a run-code call."""
    code: int  # server status code (not an HTTP status)
    message: str
    data: Optional[CodeRunData]
|
| 36 |
+
|
| 37 |
+
class CreateSessionOutput(BaseModel):
    """Sandbox server response for a create-session call."""
    code: int
    message: str
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ErrorResponse(BaseModel):
    """Generic error envelope returned by the sandbox server."""
    code: int
    message: str
    data: Optional[Any]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class UploadOutput(BaseModel):
    """Sandbox server response for a file upload."""
    code: int
    message: Optional[str]
    data: Optional[str]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Model for successful response (assuming it's a text file for this example)
|
| 55 |
+
# Model for successful response (assuming it's a text file for this example)
class DownloadSuccessOutput(BaseModel):
    file_name: str  # this is not part of server response. We must fill this field in client.
    content: str
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class HeartbeatOutput(BaseModel):
    """Sandbox server response for a heartbeat/keep-alive call."""
    code: Optional[int]
    message: Optional[str]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class RefreshSandboxOutput(BaseModel):
    """Sandbox server response for a refresh/reset call."""
    code: Optional[int]
    message: Optional[str]
|
| 68 |
+
|
| 69 |
+
|
src/infiagent/services/__init__.py
ADDED
|
File without changes
|
src/infiagent/services/chat_complete_service.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from typing import Any, Dict, List, Union
|
| 4 |
+
|
| 5 |
+
from fastapi import UploadFile
|
| 6 |
+
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 7 |
+
from werkzeug.datastructures import FileStorage
|
| 8 |
+
|
| 9 |
+
from ..conversation_sessions import CodeInterpreterSession
|
| 10 |
+
from ..exceptions.exceptions import (
|
| 11 |
+
DependencyException,
|
| 12 |
+
InputErrorException,
|
| 13 |
+
InternalErrorException,
|
| 14 |
+
ModelMaxIterationsException,
|
| 15 |
+
)
|
| 16 |
+
from ..schemas import Message, RoleType
|
| 17 |
+
from ..utils import get_logger
|
| 18 |
+
from ..tools import AsyncPythonSandBoxTool
|
| 19 |
+
|
| 20 |
+
logger = get_logger()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def predict(
        prompt: str,
        model_name: str,
        config_path: str,
        uploaded_files: Any,
        **kwargs: Dict[str, Any]):
    """Run one full code-interpreter turn: create a session, upload files, chat, clean up.

    :param prompt: user instruction for this turn
    :param model_name: LLM config key used to build the session
    :param config_path: path to the agent/tool configuration
    :param uploaded_files: a path string, or a list of path strings / UploadFile objects
    :param kwargs: forwarded verbatim to CodeInterpreterSession.create
    :return: concatenated agent output text (None if the session yields nothing)
    :raises Exception: wraps any failure with a user-facing message
    """
    # NOTE(review): upload_files is referenced here but this section's imports only bring
    # in get_logger; the name resolves only because a later import block in this file also
    # imports upload_files — verify import ordering.
    # NOTE(review): a second `predict` definition later in this file shadows this one at
    # import time (likely a merge artifact) — confirm which version is intended.
    start_time = time.time()

    # create new session
    session = await CodeInterpreterSession.create(
        model_name=model_name,
        config_path=config_path,
        **kwargs
    )

    files = upload_files(uploaded_files, session.session_id)
    logger.info(f"Session Creation Latency: {time.time() - start_time}")

    # upload file
    if isinstance(files, str):
        logger.info(f"Upload {files} as file path")
        await session.upload_to_sandbox(files)
    # upload list of file
    elif isinstance(files, list):
        for file in files:
            if isinstance(file, str):
                await session.upload_to_sandbox(file)
            elif isinstance(file, UploadFile) or isinstance(file, StarletteUploadFile):
                file_content = file.file.read()  # get file content
                file_like_object = BytesIO(file_content)
                file_storage = FileStorage(
                    stream=file_like_object,
                    filename=file.filename,
                    content_type=file.content_type
                )
                await session.upload_to_sandbox(file_storage)
            else:
                raise InputErrorException("The file type {} not supported, can't be uploaded".format(type(file)))

    # chat
    try:
        logger.info(f"Instruction message: {prompt}")
        content = None
        output_files = []
        user_messages = [Message(RoleType.User, prompt)]
        async for response in session.chat(user_messages):
            logger.info(f'Session Chat Response: {response}')
            # Concatenate streamed chunks into a single answer string.
            if content is None:
                content = response.output_text
            else:
                content += response.output_text

            output_files.extend([output_file.__dict__() for output_file in response.output_files])

        session.messages.append(Message(RoleType.Agent, content))
        # Release the sandbox kernel once the turn is complete.
        AsyncPythonSandBoxTool.kill_kernels(session.session_id)
        logger.info(f"Release python sandbox {session.session_id}")
        logger.info(f"Total Latency: {time.time() - start_time}")

        return content

    except (ModelMaxIterationsException, DependencyException, InputErrorException, InternalErrorException, Exception) \
            as e:
        # Map known failure categories to user-facing messages; anything else is "unknown".
        exception_messages = {
            ModelMaxIterationsException: "Sorry. The agent didn't find the correct answer after multiple trials, "
                                         "Please try another question.",
            DependencyException: "Agent failed to process message due to dependency issue. You can try it later. "
                                 "If it still happens, please contact oncall.",
            InputErrorException: "Agent failed to process message due to value issue. If you believe all input are "
                                 "correct, please contact oncall.",
            InternalErrorException: "Agent failed to process message due to internal error, please contact oncall.",
            Exception: "Agent failed to process message due to unknown error, please contact oncall."
        }
        err_msg = exception_messages.get(type(e), f"Unknown error occurred: {str(e)}")
        logger.error(err_msg, exc_info=True)

        raise Exception(err_msg)
|
| 100 |
+
|
| 101 |
+
import time
|
| 102 |
+
from typing import Union, List, Any, Dict
|
| 103 |
+
from io import BytesIO
|
| 104 |
+
|
| 105 |
+
from fastapi import UploadFile
|
| 106 |
+
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 107 |
+
|
| 108 |
+
from ..conversation_sessions import CodeInterpreterSession
|
| 109 |
+
from ..schemas import (
|
| 110 |
+
Message,
|
| 111 |
+
RoleType
|
| 112 |
+
)
|
| 113 |
+
from werkzeug.datastructures import FileStorage
|
| 114 |
+
|
| 115 |
+
from ..exceptions.exceptions import InputErrorException, DependencyException, InternalErrorException, \
|
| 116 |
+
ModelMaxIterationsException
|
| 117 |
+
|
| 118 |
+
from ..utils import get_logger, upload_files
|
| 119 |
+
|
| 120 |
+
logger = get_logger()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
async def predict(
        prompt: str,
        model_name: str,
        uploaded_files: Any,
        **kwargs: Dict[str, Any]):
    """Run one code-interpreter turn: create a session, upload files, chat, return text.

    :param prompt: user instruction for this turn
    :param model_name: LLM config key used to build the session
    :param uploaded_files: a path string, or a list of path strings / UploadFile objects
    :param kwargs: forwarded verbatim to CodeInterpreterSession.create
    :return: concatenated agent output text (None if the session yields nothing)
    :raises Exception: wraps any failure with a user-facing message
    """
    # NOTE(review): this is the second `predict` defined in this file and shadows the
    # earlier one (which also took config_path and killed the sandbox kernel on success)
    # — likely a merge artifact; confirm which version should survive.
    start_time = time.time()

    # create new session
    session = await CodeInterpreterSession.create(
        model_name=model_name,
        **kwargs
    )

    files = upload_files(uploaded_files, session.session_id)
    logger.info(f"Session Creation Latency: {time.time() - start_time}")

    # upload file
    if isinstance(files, str):
        logger.info(f"Upload {files} as file path")
        await session.upload_to_sandbox(files)
    # upload list of file
    elif isinstance(files, list):
        for file in files:
            if isinstance(file, str):
                await session.upload_to_sandbox(file)
            elif isinstance(file, UploadFile) or isinstance(file, StarletteUploadFile):
                file_content = file.file.read()  # get file content
                file_like_object = BytesIO(file_content)
                file_storage = FileStorage(
                    stream=file_like_object,
                    filename=file.filename,
                    content_type=file.content_type
                )
                await session.upload_to_sandbox(file_storage)
            else:
                raise InputErrorException("The file type {} not supported, can't be uploaded".format(type(file)))

    # chat
    try:
        logger.info(f"Instruction message: {prompt}")
        content = None
        output_files = []
        user_messages = [Message(RoleType.User, prompt)]

        async for response in session.chat(user_messages):
            logger.info(f'Session Chat Response: {response}')
            # Concatenate streamed chunks into a single answer string.
            if content is None:
                content = response.output_text
            else:
                content += response.output_text

            output_files.extend([output_file.__dict__() for output_file in response.output_files])

        session.messages.append(Message(RoleType.Agent, content))

        logger.info(f"Total Latency: {time.time() - start_time}")

        return content
    except (ModelMaxIterationsException, DependencyException, InputErrorException, InternalErrorException, Exception) \
            as e:
        # Map known failure categories to user-facing messages; anything else is "unknown".
        exception_messages = {
            ModelMaxIterationsException: "Sorry. The agent didn't find the correct answer after multiple trials, "
                                         "Please try another question.",
            DependencyException: "Agent failed to process message due to dependency issue. You can try it later. "
                                 "If it still happens, please contact oncall.",
            InputErrorException: "Agent failed to process message due to value issue. If you believe all input are "
                                 "correct, please contact oncall.",
            InternalErrorException: "Agent failed to process message due to internal error, please contact oncall.",
            Exception: "Agent failed to process message due to unknown error, please contact oncall."
        }
        err_msg = exception_messages.get(type(e), f"Unknown error occurred: {str(e)}")
        logger.error(err_msg, exc_info=True)

        raise Exception(err_msg)
|