diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8b65f432368d0a7ad6051982e4cee70c6e0c5577 --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +build/doctrees +build/html + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# JetBrains PyCharm specific +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, GoLand, Rider and Android Studio +.idea/ +*.iml + +# User-specific stuff +*.swp +*~ +.Session.vim +/.sass-cache diff --git a/Dockerfile b/Dockerfile new file mode 
# Keyword arguments used to build the keep-alive comment event sent as SSE ping.
IGNORE_PING_COMMENT = {"comment": "IGNORE PING"}
# Sentinel item that marks the end of a streamed response.
DONE = "[DONE]"


async def async_sse_response_format(response_data_gen):
    """Wrap every item of an async generator in a ``ServerSentEvent``.

    The ``DONE`` sentinel is forwarded verbatim as the event data. Any other
    item is placed in the same envelope that ``json_response_format`` builds
    and serialized to JSON (non-ASCII characters are kept as-is).

    Args:
        response_data_gen: async iterable producing response items.

    Yields:
        One ``ServerSentEvent`` per input item, in input order.
    """
    async for content in response_data_gen:
        if content == DONE:
            yield ServerSentEvent(data=DONE)
            continue

        envelope = json_response_format(content)
        yield ServerSentEvent(data=json.dumps(envelope, ensure_ascii=False))


def json_response_format(content):
    """Return the response envelope: the content plus a fresh ``ResponseBaseData`` dict."""
    envelope = {"response": content}
    envelope["ResponseBase"] = ResponseBaseData().dict()
    return envelope


def get_ignore_ping_comment():
    """Return a zero-argument factory that builds the "IGNORE PING" comment event."""

    def _ping_event():
        return ServerSentEvent(**IGNORE_PING_COMMENT)

    return _ping_event
async def _done_event_stream():
    """Yield the single terminating ``[DONE]`` event sent to SSE clients on failure."""
    yield ServerSentEvent(data=DONE)


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Turn any unhandled exception into a well-formed error response.

    Requests to an SSE path receive an event stream that contains only the
    ``[DONE]`` event; every other path receives a JSON body with HTTP 500.

    Args:
        request: the incoming request; only ``request.url.path`` is read.
        exc: the exception that escaped the route handler.

    Returns:
        ``EventSourceResponse`` for SSE paths, otherwise ``JSONResponse``.
    """
    error_msg = "Failed to handle request. Internal Server error: {}".format(str(exc))
    # Pass the exception itself so the traceback is logged even when the
    # handler is not running inside an active ``except`` block.
    logger.error(error_msg, exc_info=exc)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse needs an (async) iterable of events; a bare
        # ServerSentEvent instance is not iterable and fails while streaming.
        return EventSourceResponse(_done_event_stream())

    return JSONResponse(
        status_code=500,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict()
        }
    )


@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Turn an ``HTTPException`` into a well-formed error response.

    Mirrors ``general_exception_handler`` but keeps the status code carried
    by the exception for the JSON variant.

    Args:
        request: the incoming request; only ``request.url.path`` is read.
        exc: the raised ``HTTPException`` (``detail`` and ``status_code`` are used).

    Returns:
        ``EventSourceResponse`` for SSE paths, otherwise ``JSONResponse``.
    """
    error_msg = "Failed to handle request. Error: {}".format(exc.detail)
    logger.error(error_msg, exc_info=exc)

    if request.url.path in SSE_API_PATHS:
        # Same fix as above: stream the [DONE] event through a real iterable.
        return EventSourceResponse(_done_event_stream())

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict()
        }
    )
@complete_chat_router.post("/complete_sse")
async def complete_sse(request: Request):
    """Stream a chat completion as server-sent events.

    The raw request body is parsed into a ``ChatCompleteRequest``; a body that
    fails validation is rejected with HTTP 400. The events produced by
    ``chat_event_generator`` are formatted by ``async_sse_response_format``
    and pings are sent as "IGNORE PING" comment events.
    """
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
        logger.info("Got chat request: {}".format(chat_request))
    except ValidationError as validation_error:
        detail = "Invalid input chat_request. Error: {}".format(str(validation_error))
        raise HTTPException(status_code=400, detail=detail)

    def _ping_event():
        # Keep-alive event that clients are expected to ignore.
        return ServerSentEvent(**IGNORE_PING_COMMENT)

    event_stream = async_sse_response_format(chat_event_generator(chat_request))
    return EventSourceResponse(event_stream, ping_message_factory=_ping_event)
class UploadedFile(io.BytesIO):
    """In-memory copy of a local file that also carries upload metadata.

    The whole file is read eagerly at construction time and becomes the
    content of the underlying ``io.BytesIO``. Three extra attributes are set:

    * ``name`` - base name of ``path`` (directory part removed)
    * ``type`` - always ``'application/octet-stream'``
    * ``size`` - number of bytes read from the file
    """

    def __init__(self, path):
        """Load ``path`` (a ``str`` or ``os.PathLike``) fully into memory.

        Raises:
            OSError: if the file cannot be opened or read.
        """
        with open(path, 'rb') as file:
            data = file.read()

        super().__init__(data)

        # basename() honours the platform separator and accepts path-like
        # objects, unlike splitting the string on "/".
        self.name = os.path.basename(path)
        self.type = 'application/octet-stream'  # generic binary MIME type
        self.size = len(data)

    def __repr__(self):
        # Report the real class name (the old text named a class that does not exist).
        return f"{type(self).__name__}(name={self.name}, size={self.size}, type={self.type})"

    def __len__(self):
        """Return the size in bytes recorded when the file was loaded."""
        return self.size
def read_questions(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The path is echoed to stdout before reading. The file is decoded as
    UTF-8 explicitly so the result does not depend on the locale's default
    encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    print(file_path)
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)


def extract_data_from_folder(folder_path):
    """Load every ``*.questions`` file found directly inside ``folder_path``.

    Args:
        folder_path: directory to scan (not recursive).

    Returns:
        dict mapping each file name without its extension to the parsed JSON
        content of that file. Entries are inserted in sorted file-name order
        so the result is reproducible across platforms.
    """
    print(f'folder_path {folder_path}')
    extracted_data = {}
    # os.listdir() order is unspecified; sort for deterministic output.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith('.questions'):
            continue
        stem = os.path.splitext(file_name)[0]
        extracted_data[stem] = read_questions(os.path.join(folder_path, file_name))

    return extracted_data
try:
    import infiagent
    from infiagent.utils import get_logger, upload_files
    from infiagent.services.chat_complete_service import predict
except ImportError as import_error:
    # Raising a plain string is itself a TypeError; raise a real exception
    # and keep the original import failure attached as its cause.
    raise ImportError(
        "import infiagent failed, please install infiagent by 'pip install -e .' in the pipeline directory of "
        "ADA-Agent") from import_error

logger = get_logger()

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


def _get_script_params():
    """Parse the command-line options of the demo.

    Recognised options (all optional): ``--llm``, ``--api_key`` and
    ``--config_path``.

    Returns:
        The parsed ``argparse.Namespace``, or ``None`` when parsing raised an
        ``Exception`` (the failure is logged).
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('--llm',
                            help='LLM Model for demo',
                            required=False, type=str)
        parser.add_argument('--api_key',
                            help='Open API token key.',
                            required=False, type=str)
        parser.add_argument('--config_path',
                            help='Config path for demo',
                            required=False, type=str)

        return parser.parse_args()
    except Exception as e:
        logger.error("Failed to get script input arguments: {}".format(str(e)), exc_info=True)

    return None


async def main():
    """Render the Streamlit demo page and answer the submitted prompt.

    Reads the script options, prepares the OpenAI key when the model name
    contains ``OPEN_AI``, then draws the prompt box, file uploader and run
    button. When the button is pressed with a non-empty prompt, ``predict``
    is awaited and both the prompt and the answer are appended to the chat
    history kept in ``st.session_state``.

    Raises:
        ValueError: an ``OPEN_AI`` model is requested but no key is available
            from the environment or from ``--api_key``.
    """
    args = _get_script_params()

    model_name = getattr(args, "llm", None)
    open_ai_key = getattr(args, "api_key", None)
    config_path = getattr(args, "config_path", None)

    # ``--llm`` is optional, so model_name may be None; guard before the
    # substring test to avoid "argument of type 'NoneType' is not iterable".
    if model_name and "OPEN_AI" in model_name:
        logger.info("setup open ai ")
        if os.environ.get("OPENAI_API_KEY") is None:
            if open_ai_key:
                openai.api_key = open_ai_key
                os.environ["OPENAI_API_KEY"] = open_ai_key
            else:
                raise ValueError(
                    "OPENAI_API_KEY is None, please provide open ai key to use open ai model. "
                    "Adding '--api_key' to set it up")

        # Raise the 'openai' logger to WARNING so its INFO records are not printed.
        openai_logger = logging.getLogger('openai')
        openai_logger.setLevel(logging.WARNING)

    else:
        logger.info("use local model ")

    st.set_page_config(layout="centered")

    st.title("InfiAgent Code Interpreter Demo 🚀")

    # Initialize session state variables if not already present
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # UI components
    input_text = st.text_area("Write your prompt")
    uploaded_files = st.file_uploader("Upload your files", accept_multiple_files=True)
    button_pressed = st.button("Run code interpreter", use_container_width=True)

    # When button is pressed
    if button_pressed and input_text != "":
        # Add user message to chat history
        st.session_state.chat_history.append({"role": "user", "message": input_text})

        response = await predict(
            prompt=input_text,
            model_name=model_name,
            config_path=config_path,
            uploaded_files=uploaded_files,
        )

        # Add assistant message to chat history
        st.session_state.chat_history.append({"role": "assistant", "message": response})

    # Display chat history
    for chat in st.session_state.chat_history:
        with st.chat_message(chat["role"]):
            st.write(chat["message"])


if __name__ == "__main__":
    asyncio.run(main())
@local_app.post("/local_sse_test")
async def complete_sse(request: Request):
    """Local SSE test endpoint.

    Parses the raw body into a ``ChatCompleteRequest`` and streams the events
    of ``chat_local_event_generator`` through ``async_sse_response_format``.
    A body that fails validation is reported as an ``HTTPException`` with
    status 500.
    """
    payload = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(payload)
        logger.info("Got chat request: {}".format(chat_request))
    except ValidationError as e:
        # NOTE(review): a client-side validation failure is reported as 500
        # here; confirm whether a 4xx status is intended for this test app.
        raise HTTPException(
            status_code=500,
            detail="Invalid input chat_request. Error: {}".format(str(e)),
        )

    formatted_events = async_sse_response_format(chat_local_event_generator(chat_request))
    ping_factory = get_ignore_ping_comment()
    return EventSourceResponse(formatted_events, ping_message_factory=ping_factory)
@predict_router.post("/predict")
async def chat_predict(
    prompt: str = Form(...),
    model_name: str = Form(...),
    psm: Optional[str] = Form(None),
    dc: Optional[str] = Form(None),
    temperature: Optional[str] = Form(None),
    top_p: Optional[str] = Form(None),
    top_k: Optional[str] = Form(None),
    files: List[UploadFile] = File(...)
):
    """Run ``predict`` on a multipart form submission.

    Optional form fields are forwarded as keyword arguments only when they
    are non-empty: ``psm`` and ``dc`` unchanged, ``temperature``, ``top_p``
    and ``top_k`` converted with ``float``.

    Returns:
        ``{"answer": <value returned by predict>}``

    Raises:
        ValueError: a supplied numeric field cannot be converted by ``float``.
    """
    text_options = (("psm", psm), ("dc", dc))
    numeric_options = (("temperature", temperature), ("top_p", top_p), ("top_k", top_k))

    kwargs = {key: value for key, value in text_options if value}
    for key, value in numeric_options:
        if value:
            kwargs[key] = float(value)

    # NOTE(review): other callers of predict in this change pass
    # ``config_path=`` and ``uploaded_files=`` by keyword; confirm that the
    # third positional parameter really is the file list.
    answer = await predict(prompt, model_name, files, **kwargs)

    return {"answer": answer}
async def get_gen_prompt(request) -> str:
    """Build the generation prompt for a chat request with a fastchat template.

    A private copy of the conversation template for ``request.model`` is
    created. When ``request.messages`` is a plain string it is returned
    unchanged; otherwise each message is applied to the copy by role and the
    rendered prompt, ending with an empty assistant turn, is returned.

    Raises:
        ModuleNotFoundError: fastchat is not installed.
        ImportError: the installed fastchat is older than 0.2.23.
        ValueError: a message has a role other than system/user/assistant.
    """
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "fastchat is not installed. Please install fastchat to use "
            "the chat completion and conversation APIs: `$ pip install fschat`"
        )
    installed = version.parse(fastchat.__version__)
    if installed < version.parse("0.2.23"):
        raise ImportError(
            f"fastchat version is low. Current version: {fastchat.__version__} "
            "Please upgrade fastchat to use: `$ pip install -U fschat`")

    template = get_conversation_template(request.model)
    # Copy field by field (with a fresh message list) so that appending
    # messages below does not modify the template object in place.
    conv = Conversation(
        name=template.name,
        system_template=template.system_template,
        system_message=template.system_message,
        roles=template.roles,
        messages=list(template.messages),
        offset=template.offset,
        sep_style=SeparatorStyle(template.sep_style),
        sep=template.sep,
        sep2=template.sep2,
        stop_str=template.stop_str,
        stop_token_ids=template.stop_token_ids,
    )

    if isinstance(request.messages, str):
        return request.messages

    # Index into conv.roles used for each speaking role.
    speaker_slot = {"user": 0, "assistant": 1}
    for message in request.messages:
        role = message["role"]
        if role == "system":
            conv.system_message = message["content"]
            continue
        if role not in speaker_slot:
            raise ValueError(f"Unknown role: {role}")
        speaker = conv.roles[speaker_slot[role]]
        conv.append_message(speaker, message["content"])

    # Add a blank message for the assistant.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()
" + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.", + ) + else: + return input_ids, None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/v1/models") +async def show_available_models(): + """Show available models. Right now we only have one model.""" + model_cards = [ + ModelCard(id=served_model, + root=served_model, + permission=[ModelPermission()]) + ] + return ModelList(data=model_cards) + + +def create_logprobs(token_ids: List[int], + id_logprobs: List[Dict[int, float]], + initial_text_offset: int = 0) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + for token_id, id_logprob in zip(token_ids, id_logprobs): + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(id_logprob[token_id]) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + + logprobs.top_logprobs.append({ + tokenizer.convert_ids_to_tokens(i): p + for i, p in id_logprob.items() + }) + return logprobs + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. 
+ + NOTE: Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (to be supported by vLLM engine) + """ + logger.info(f"Received chat completion request: {request}") + + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in vLLM engine. + return create_error_response(HTTPStatus.BAD_REQUEST, + "logit_bias is not currently supported") + + prompt = await get_gen_prompt(request) + token_ids, error_check_ret = await check_length(request, prompt=prompt) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + try: + # spaces_between_special_tokens = request.spaces_between_special_tokens + sampling_params = SamplingParams( + n=request.n, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + temperature=request.temperature, + top_p=request.top_p, + stop=request.stop, + stop_token_ids=request.stop_token_ids, + max_tokens=request.max_tokens, + best_of=request.best_of, + top_k=request.top_k, + ignore_eos=request.ignore_eos, + use_beam_search=request.use_beam_search, + skip_special_tokens=request.skip_special_tokens, + # spaces_between_special_tokens=spaces_between_special_tokens, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + result_generator = engine.generate(prompt, sampling_params, request_id, + token_ids) + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=text), + finish_reason=finish_reason, + ) + response = ChatCompletionStreamResponse( + id=request_id, + created=created_time, + 
model=model_name, + choices=[choice_data], + ) + response_json = response.json(ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + response_json = create_stream_response_json( + index=i, + text=delta_text, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + response_json = create_stream_response_json( + index=i, + text="", + finish_reason=output.finish_reason, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type="text/event-stream") + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification. This API mimics the OpenAI Completion API.

    NOTE: Currently we do not support the following features:
        - echo (since the vLLM engine does not currently support
          getting the logprobs of prompt tokens)
        - suffix (the language models we currently support do not support
          suffix)
        - logit_bias (to be supported by vLLM engine)

    Args:
        request: The parsed completion request body.
        raw_request: The underlying HTTP request, polled to detect client
            disconnects while a non-streaming result is being generated.

    Returns:
        The value of ``create_error_response(...)`` when validation fails or
        the client disconnects, a ``StreamingResponse`` of server-sent events
        when ``request.stream`` is set, otherwise a ``CompletionResponse``.
    """
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the vLLM engine does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None and len(request.logit_bias) > 0:
        # TODO: support logit_bias in vLLM engine.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"

    # The prompt may be a plain string, a list of token ids, or a
    # single-element list wrapping either of those.
    use_token_ids = False
    if isinstance(request.prompt, list):
        if len(request.prompt) == 0:
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "please provide at least one prompt")
        first_element = request.prompt[0]
        if isinstance(first_element, int):
            use_token_ids = True
            prompt = request.prompt
        elif isinstance(first_element, (str, list)):
            # TODO: handles multiple prompt case in list[list[int]]
            if len(request.prompt) > 1:
                return create_error_response(
                    HTTPStatus.BAD_REQUEST,
                    "multiple prompts in a batch is not currently supported")
            use_token_ids = not isinstance(first_element, str)
            prompt = request.prompt[0]
    else:
        prompt = request.prompt

    if use_token_ids:
        _, error_check_ret = await check_length(request, prompt_ids=prompt)
    else:
        token_ids, error_check_ret = await check_length(request, prompt=prompt)
    if error_check_ret is not None:
        return error_check_ret

    # `created` is reported to clients as a Unix timestamp, so use wall-clock
    # time; time.monotonic() has an undefined reference point.
    created_time = int(time.time())
    try:
        # spaces_between_special_tokens = request.spaces_between_special_tokens
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            stop_token_ids=request.stop_token_ids,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
            skip_special_tokens=request.skip_special_tokens,
            # spaces_between_special_tokens=spaces_between_special_tokens,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    if use_token_ids:
        result_generator = engine.generate(None,
                                           sampling_params,
                                           request_id,
                                           prompt_token_ids=prompt)
    else:
        result_generator = engine.generate(prompt, sampling_params, request_id,
                                           token_ids)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream
              and (request.best_of is None or request.n == request.best_of)
              and not request.use_beam_search)

    def create_stream_response_json(
        index: int,
        text: str,
        logprobs: Optional[LogProbs] = None,
        finish_reason: Optional[str] = None,
    ) -> str:
        """Serialize one streamed choice chunk to a JSON string."""
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        """Yield server-sent events carrying only the newly generated text."""
        # Per-choice progress, used to emit deltas rather than full outputs.
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    # Final chunk for this choice: empty text plus the reason.
                    logprobs = (LogProbs()
                                if request.logprobs is not None else None)
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "Client disconnected")
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response
+ tokenizer = get_tokenizer(engine_args.tokenizer, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code) + + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) \ No newline at end of file diff --git a/configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml b/configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c6c91c86d44fc97dd09e13d4d2640082e5ad19e --- /dev/null +++ b/configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: react_template +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: gpt-35-turbo + module_name: infiagent.llm + class_name: AzureOpenAIGPTClient + params: + temperature: 0.2 + top_p: 0.95 + repetition_penalty: 1.0 + max_tokens: 4096 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox.yaml \ No newline at end of file diff --git a/configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml b/configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78fafddf063eef5f52b9091f3d78feca8b9d9821 --- /dev/null +++ b/configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: gpt_4_react +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: gpt-4-0613 + module_name: infiagent.llm + class_name: AzureOpenAIGPTClient + params: + temperature: 0.2 + top_p: 0.95 +
repetition_penalty: 1.0 + max_tokens: 4096 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox.yaml \ No newline at end of file diff --git a/configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml b/configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2364bea0638916778ceb9030bfca6e9acfd9f94f --- /dev/null +++ b/configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: gpt_4_react +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: gpt-4-0613 + module_name: infiagent.llm + class_name: AzureOpenAIGPTClient + params: + temperature: 0.2 + top_p: 0.95 + repetition_penalty: 1.0 + max_tokens: 4096 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox_docker.yaml \ No newline at end of file diff --git a/configs/agent_configs/react_agent_gpt4_async.yaml b/configs/agent_configs/react_agent_gpt4_async.yaml new file mode 100644 index 0000000000000000000000000000000000000000..668d119bef42e93ece5df95451f825b32da6d950 --- /dev/null +++ b/configs/agent_configs/react_agent_gpt4_async.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: react_template +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: gpt-4 + module_name: infiagent.llm + class_name: OpenAIGPTClient + params: + temperature: 0.0 + top_p: 0.9 + repetition_penalty: 1.0 + max_tokens: 1024 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: 
python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox.yaml \ No newline at end of file diff --git a/configs/agent_configs/react_agent_llama_async.yaml b/configs/agent_configs/react_agent_llama_async.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60105133d06f61e574123cf1f3a6a9fc382081f7 --- /dev/null +++ b/configs/agent_configs/react_agent_llama_async.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: react_template +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: meta-llama/Llama-2-7b-hf + module_name: infiagent.llm + class_name: LlamaOpenAIClient + params: + temperature: 0.0 + top_p: 0.9 + repetition_penalty: 1.0 + max_tokens: 1024 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox.yaml \ No newline at end of file diff --git a/configs/agent_configs/react_agent_opt_async.yaml b/configs/agent_configs/react_agent_opt_async.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7962afacc6bae6dfec7fd04938efc2648ba07308 --- /dev/null +++ b/configs/agent_configs/react_agent_opt_async.yaml @@ -0,0 +1,23 @@ +# ReAct Agent Template +name: react_template +version: 0.0.1 +type: react +description: A react agent capable of code interpreter +module_name: infiagent.agent.react +class_name: AsyncReactAgent +target_tasks: + - code interpreter +llm: + model_name: facebook/opt-125m + module_name: infiagent.llm + class_name: OptOpenAIClient + params: + temperature: 0.0 + top_p: 0.9 + repetition_penalty: 1.0 + max_tokens: 1024 +prompt_template: !prompt ZeroShotReactPrompt +plugins: + - name: python_code_sandbox + type: tool + config: configs/tool_configs/async_python_code_sandbox.yaml \ No newline at end of file diff --git 
a/configs/tool_configs/async_python_code_sandbox.yaml b/configs/tool_configs/async_python_code_sandbox.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8fc0791b40de7f21b189e23a4be8082c60aeba8 --- /dev/null +++ b/configs/tool_configs/async_python_code_sandbox.yaml @@ -0,0 +1,7 @@ +name: python_code_sandbox +version: 0.0.1 +type: tool +description: this tool can help to run python script with python code as input +module_name: infiagent.tools +class_name: AsyncPythonSandBoxTool +session_id: none \ No newline at end of file diff --git a/configs/tool_configs/async_python_code_sandbox_docker.yaml b/configs/tool_configs/async_python_code_sandbox_docker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7e13b99c30cdb1aa544e8137f60e8b5c8758de0 --- /dev/null +++ b/configs/tool_configs/async_python_code_sandbox_docker.yaml @@ -0,0 +1,7 @@ +name: python_code_sandbox +version: 0.0.1 +type: tool +description: this tool can help to run python script with python code as input +module_name: infiagent.tools +class_name: CodeTool +session_id: none \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..db64c0261ced33f24ffda9ab9057db0f872f4b2e --- /dev/null +++ b/run.sh @@ -0,0 +1,3 @@ +#!/bin/bash +set -ex +poetry run python3 -m uvicorn src.activities.api:app --reload --host 0.0.0.0 --port ${PORT:-3000} --limit-max-requests 5000 --timeout-keep-alive 1200 diff --git a/run_demo.sh b/run_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..d684fada4c9257281cb28902864bf7a57fbf2e8b --- /dev/null +++ b/run_demo.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# set -ex + +streamlit run ./activities/local_demo.py --server.port 6006 -- $@ + diff --git a/run_local.sh b/run_local.sh new file mode 100644 index 0000000000000000000000000000000000000000..99b9c4f0b7d37edba2c787567b90a78f41e5bc00 --- /dev/null +++ b/run_local.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -ex 
"""Packaging script for the ``infiagent`` distribution."""
from setuptools import setup, find_packages

# Read the long description inside a context manager so the handle is closed
# deterministically, and decode as UTF-8 regardless of the build machine's locale.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name='infiagent',
    version='0.1.0',
    author='InfiAgent',
    # Sources live under src/ (src-layout).
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    url='https://github.com/InfiAgent/ADA-Agent',
    license='LICENSE.txt',
    description='An awesome package for InfiAgent.',
    long_description=long_description,
    package_data={
        'infiagent.configs.agent_configs': ['*.yaml'],
        'infiagent.configs.tool_configs': ['*.yaml'],
    },
    install_requires=[
        "streamlit",
        "pyyaml",
        "pytest",
        "openai==0.27.7",
        "fastapi",
        "uvicorn",
        "uvloop",
        "watchdog",
        "chardet",
        "werkzeug",
        "python-dotenv",
        "motor",
        "aiofiles",
        "sse_starlette",
        "loguru",
        "jupyter_client",
        "pandas",
        "scikit-learn",
        "scipy",
        "ipykernel"
    ],
    python_requires='>=3.9'
)
class BaseAgent(ABC):
    """Base Agent class defining the essential attributes and methods for an ALM Agent.

    An agent owns one LLM client, a map of named plugins (tools or other
    agents) and an optional prompt template. Instances are normally built from
    a config file via ``from_config_path_and_kwargs`` or its async counterpart.
    """

    def __init__(self, **kwargs):
        """
        Initializes an instance of the Agent class.

        :param kwargs: Configuration values; recognised keys are ``name``,
            ``type``, ``version``, ``description``, ``prompt_template`` and
            ``auth``. Unknown keys are accepted and ignored here.
        """
        # Set default values
        default_config = {
            'name': 'agent',
            'type': AgentType.react,
            'version': '',
            'description': '',
            'prompt_template': None,
            'auth': {}
        }
        # Update default values with provided config
        default_config.update(kwargs)

        # Access configuration data with a known default value
        auth = default_config['auth']
        self._set_auth_env(auth)

        self._name: str = default_config['name']
        self._type: AgentType = default_config['type']
        self._version: str = default_config['version']
        self._description: str = default_config['description']
        self.__prompt_template: Union[PromptTemplate, None] = \
            self._get_prompt_template(default_config['prompt_template'])
        self.__llm: Union[BaseLLM, None] = None
        self.__plugins_map: Dict = {}
        # Memoized views derived from the plugin map (see add_plugin).
        self.__plugin_tool_function = {}
        self.__plugin_tool_async_function = {}
        self.__plugin_tool_description = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def type(self) -> AgentType:
        return self._type

    @property
    def version(self) -> str:
        return self._version

    @property
    def description(self) -> str:
        return self._description

    @property
    def prompt_template(self) -> PromptTemplate:
        return self.__prompt_template

    @property
    def llm(self) -> Union[BaseLLM, None]:
        return self.__llm

    @llm.setter
    def llm(self, llm_client: BaseLLM):
        if llm_client is None or not isinstance(llm_client, BaseLLM):
            raise InputErrorException("Invalid llm client {}".format(type(llm_client)))
        self.__llm = llm_client

    @property
    def plugins_map(self) -> Dict:
        return self.__plugins_map.copy()  # Return a copy to prevent external modification

    def add_plugin(self, tool_name: str, tool):
        """Register ``tool`` under ``tool_name``.

        :raises InputErrorException: If the name or the tool is empty/falsy.
        """
        if not tool_name or not tool:
            raise InputErrorException("Adding invalid tool name: {}, type {}".format(tool_name, type(tool)))
        self.__plugins_map[tool_name] = tool
        # The memoized function maps and description were computed from the
        # previous plugin set; drop them so they are rebuilt lazily.
        self.__plugin_tool_function = {}
        self.__plugin_tool_async_function = {}
        self.__plugin_tool_description = None

    def _set_auth_env(self, obj):
        """This method sets environment variables for authentication.
        """
        for key in obj:
            os.environ[key] = obj.get(key)

    def _get_prompt_template(self, obj):
        """This method returns a prompt template instance based on the provided configuration.

        :param obj: ``None`` (no template configured), a ``PromptTemplate``,
            or a dict mapping names to template configs/instances.
        :raises InputErrorException: If ``obj`` is of any other type.
        """
        if obj is None:
            # 'prompt_template' defaults to None in __init__; treat it as "not set".
            return None
        if isinstance(obj, dict):
            return {
                key: self._parse_prompt_template(obj[key]) for key in obj
            }
        if isinstance(obj, PromptTemplate):
            return self._parse_prompt_template(obj)
        raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
                                  .format(type(obj)))

    def _parse_prompt_template(self, obj: Union[dict, PromptTemplate]):
        """This method parses the prompt template configuration and returns a prompt template instance.
        """
        assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
        if isinstance(obj, PromptTemplate):
            return obj
        return PromptTemplate(input_variables=obj['input_variables'],
                              template=obj['template'],
                              validate_template=bool(obj.get('validate_template', True)))

    @staticmethod
    def _override_llm_params(config_data, kwargs):
        """Copy user-supplied LLM parameters from ``kwargs`` into ``config_data`` in place."""
        if 'llm' in config_data and 'params' in config_data['llm']:
            for param in LLM_CONF_OVERRIDE_KEY:
                # NOTE(review): falsy overrides (e.g. temperature=0) are skipped by
                # this truthiness check -- confirm that is intended.
                if param in kwargs and kwargs[param]:
                    logger.info(f"Overwrite with new {param} {kwargs[param]}")
                    config_data['llm']['params'][param] = kwargs[param]

    @classmethod
    def _get_basic_instance_from_config(cls, config_data):
        """Instantiate the agent class named by ``module_name``/``class_name`` in the config.

        :raises InputErrorException: If either key is missing or empty.
        """
        agent_module_name = config_data.get("module_name", None)
        agent_class_name = config_data.get("class_name", None)
        if not agent_module_name or not agent_class_name:
            raise InputErrorException("Agent module_name and class_name required, please check your config")

        module = import_module(agent_module_name)
        clazz = getattr(module, agent_class_name)
        agent_instance = clazz(**config_data)
        return agent_instance

    @classmethod
    def from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build an agent synchronously from a config file.

        :param config_path: Path handed to ``Config.load``.
        :param kwargs: Optional overrides for the keys in ``LLM_CONF_OVERRIDE_KEY``.
        :return: The initialised agent with its LLM and plugins attached.
        """
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        cls._override_llm_params(config_data, kwargs)

        assert isinstance(agent_instance, BaseAgent)
        agent_instance._init_llm(config_data.get("llm", {}))
        agent_instance._init_plugins(config_data.get('plugins', []))
        return agent_instance

    def _init_llm(self, obj):
        """
        This method parses the LLM configuration, builds the client and stores it on ``self.llm``.

        :param obj: A configuration dictionary with ``module_name`` and
            ``class_name`` and, optionally, ``model_name`` and ``params``.
        :type obj: dict
        :raises InputErrorException: If the config is not a dict or lacks
            ``module_name``/``class_name``.
        """
        if not isinstance(obj, dict):
            # A bare model-name string cannot say which client class to load.
            raise InputErrorException(
                "Agent LLM config should be a dict with module_name and class_name. But get {}"
                .format(type(obj)))

        module_name = obj.get('module_name', None)
        class_name = obj.get('class_name', None)
        if not module_name or not class_name:
            raise InputErrorException("Agent LLM module_name and class_name required, please check your config")

        name = obj.get('model_name', None)
        model_params = obj.get('params', dict())

        module = import_module(module_name)
        clazz = getattr(module, class_name)

        llm = clazz(model_name=name, params=model_params)
        self.llm = llm

    def _init_plugins(self, configs):
        """
        This method parses the plugin configuration and add each plugin into the plugins_map.
        """
        assert isinstance(configs, list)
        for plugin_config in configs:
            # Register through add_plugin: the plugins_map property returns a
            # copy, so assigning into it would silently discard the plugin.
            if plugin_config.get('type', "") == 'agent':
                # Agent as plugin
                agent = BaseAgent.from_config_path_and_kwargs(plugin_config['config'])
                self.add_plugin(plugin_config['name'], agent)
            else:
                # Tools as plugin
                params = plugin_config.get('params', dict())
                tool = BaseTool.from_config(config_input=plugin_config['config'], **params)
                self.add_plugin(tool.name, tool)

    @classmethod
    async def async_from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build an agent from a config file, initialising the LLM and plugins concurrently.

        :param config_path: Path handed to ``Config.load``.
        :param kwargs: Optional overrides for the keys in ``LLM_CONF_OVERRIDE_KEY``.
        :return: The initialised agent with its LLM and plugins attached.
        """
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        # override default config with user input
        cls._override_llm_params(config_data, kwargs)

        llm_config = config_data.get("llm", {})
        plugin_configs = config_data.get('plugins', [])

        # Create tasks for llm and each individual plugin
        llm_task = asyncio.create_task(cls._async_init_llm(llm_config))
        plugin_tasks = [asyncio.create_task(cls._async_init_plugin(plugin_config)) for
                        plugin_config in plugin_configs]

        # Gather results
        llm, *plugins = await asyncio.gather(llm_task, *plugin_tasks)

        agent_instance.llm = llm
        for plugin_name, plugin_instance in plugins:
            agent_instance.add_plugin(plugin_name, plugin_instance)
        return agent_instance

    @classmethod
    async def _async_init_llm(cls, llm_config):
        """Create the LLM client described by ``llm_config`` via the client's async ``create``.

        :raises InputErrorException: If ``module_name`` or ``class_name`` is missing.
        """
        llm_model_name = llm_config.get("module_name", None)
        llm_class_name = llm_config.get("class_name", None)
        if not llm_model_name or not llm_class_name:
            raise InputErrorException("Agent LLM module_name and class_name required, please check your config")
        module = import_module(llm_model_name)
        clazz = getattr(module, llm_class_name)
        assert issubclass(clazz, BaseLLM), f"{clazz} is not a subclass of BaseLLM"
        llm_instance = await clazz.create(config_data=llm_config)
        return llm_instance

    @classmethod
    async def _async_init_plugin(cls, plugin_config):
        """Create one plugin and return it as a ``(name, instance)`` pair."""
        if plugin_config.get('type', "") == 'agent':
            # Agent as plugin
            agent = await BaseAgent.async_from_config_path_and_kwargs(plugin_config['config'])
            return plugin_config['name'], agent
        else:
            # Tool as plugin
            params = plugin_config.get('params', dict())
            name = plugin_config.get('name', None)
            config = plugin_config['config']

            tool = await BaseTool.async_from_config(config_input=config, **params)

            # Fall back to the tool's own name when the config does not give one.
            if name is None:
                name = tool.name
            logger.info("Init tool with name [{}], and description [{}]".format(name, tool.description))
            return name, tool

    @abstractmethod
    def run(self, *args, **kwargs) -> Union[AgentResponse, None]:
        """Abstract method to be overridden by child classes for running the agent.

        :return: The output of the agent.
        :rtype: AgentResponse or None
        """
        pass

    async def async_run(self, *args, **kwargs) -> AsyncGenerator[AgentResponse, None]:
        """Default async entry point: yields the single result of the synchronous ``run``.

        Child classes may override this to stream several responses.

        :return: An async generator over the agent's responses.
        """
        yield self.run(*args, **kwargs)

    def _get_plugin_function_map(self, method_name: str) -> Dict[str, Callable]:
        """Map each plugin name to its bound ``method_name`` method, memoized for ``run``/``async_run``."""
        if method_name == "run" and self.__plugin_tool_function:
            return self.__plugin_tool_function
        elif method_name == "async_run" and self.__plugin_tool_async_function:
            return self.__plugin_tool_async_function

        function_map = {}

        for name, plugin_tool in self.plugins_map.items():
            if isinstance(plugin_tool, (BaseTool, BaseAgent)):
                function_map[name] = getattr(plugin_tool, method_name)
            else:
                logger.warning(f"No support for plugin name {name} of type {type(plugin_tool)}")

        if method_name == "run":
            self.__plugin_tool_function = function_map
        elif method_name == "async_run":
            self.__plugin_tool_async_function = function_map

        return function_map

    def get_plugin_tool_function(self) -> Dict[str, Callable]:
        """Format the function map for the function API.

        :return: The function map.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("run")

    def get_plugin_tool_async_function(self) -> Dict[str, Callable]:
        """Format the async function map for the function API.

        :return: The function map.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("async_run")

    def _get_plugin_description(self):
        """Return one ``name[input]: description`` line per plugin, memoized once non-empty.

        :raises InputErrorException: If a plugin's description cannot be read.
        """
        if self.__plugin_tool_description:
            return self.__plugin_tool_description

        try:
            descriptions = "".join(
                f"{plugin_name}[input]: {plugin.description}\n"
                for plugin_name, plugin in self.plugins_map.items()
            )
        except Exception as e:
            err_msg = "Failed to get plugin tool name and description. error: {}".format(str(e))
            raise InputErrorException(err_msg) from e

        self.__plugin_tool_description = descriptions
        return descriptions

    def clear(self):
        """
        Clear and reset the agent.
        """
        pass
"AsyncReactAgent" + self._type = AgentType.react + self.__intermediate_steps: List[BaseAgentResponse] = [] + + @property + def intermediate_steps(self): + return self.__intermediate_steps + + def run(self, *args, **kwargs): + pass + + async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]): + sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME) + if not isinstance(sandbox_plugin, (AsyncPythonSandBoxTool, AsyncPythonSandBoxTool)): + raise InternalErrorException("SandBox client is not ready for agent, please check init logic.") + return await sandbox_plugin.sync_to_sandbox(file) + + async def async_run(self, agent_req: AgentRequest): + instruction = '\n'.join(message.content for message in agent_req.messages) + async for response in self._chat(instruction, is_cn=agent_req.is_cn): + yield response + + async def _chat(self, instruction: str, is_cn=False, max_iterations=10, + max_single_step_iterations=3): + current_iteration = 0 + + for _ in range(max_iterations): + current_iteration += 1 + llm_response = await self._single_round_thought(instruction, + max_llm_iteration=max_single_step_iterations, + is_cn=is_cn) + logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n" + .format(current_iteration, max_iterations, llm_response.raw_output, + llm_response.formatted_output)) + yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output) + + if isinstance(llm_response, AgentFinish): + logger.info("Find final answer, stop iteration.") + break + + self.intermediate_steps.append(llm_response) + action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration, + max_iterations, is_cn) + logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n" + .format(current_iteration, max_iterations, action_response.raw_output, + action_response.formatted_output)) + self.intermediate_steps.append(action_response) + + yield 
self.create_agent_response(action_response.formatted_output, + cur_output_files, + action_response.raw_output) + + logger.info(f"Finished iteration in {current_iteration}.") + + # TODO update logic to not be sandbox specific, sandbox related logic should be handled in sandbox client + async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False): + try: + response.tool = 'python_code_sandbox' + action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input) + logger.info( + f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n" + f"{action_response.output_text}") + + if "STDERR" in action_response.output_text: + formatted_output = self._process_sandbox_output(action_response.output_text) + else: + formatted_output = action_response.output_text + + formatted_output = replace_latex_format(formatted_output) + observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN + formatted_output = f"{observation_prefix}\n{formatted_output}\n" + + action_observation = AgentObservation(tool=response.tool, + formatted_output=formatted_output, + raw_output=action_response.output_text) + cur_output_files = self._get_output_files(action_response) + return action_observation, cur_output_files + + except Exception as e: + logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. " + f"Error: {str(e)}", exc_info=True) + # TODO: We hard code here as we only have one tool + raise SandboxException("Error occurred while running the tool") from e + + def _compose_prompt(self, instruction) -> str: + """ + Compose the prompt from template, worker description, examples and instruction. 
+ """ + agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps) + tool_description = self._get_plugin_description() + tool_names = ", ".join(list(self.plugins_map.keys())) + if self.prompt_template is None: + raise InternalErrorException("Agent prompt is none, please check init process") + + return self.prompt_template.format( + instruction=instruction, + agent_scratchpad=agent_scratchpad, + tool_description=tool_description, + tool_names=tool_names + ) + + async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \ + Union[AgentAction, AgentFinish]: + + llm_iteration_count = 0 + + llm_response = None + while llm_iteration_count <= max_llm_iteration: + llm_iteration_count += 1 + try: + llm_response = await self._get_llm_response(instruction) + action_response = self._parse_output(llm_response.content, is_cn) + + return action_response + except Exception as e: + logger.error("LLM iteration {} out of {} failed. Error: {}". + format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True) + + if llm_iteration_count > max_llm_iteration: + logger.error("LLM iteration {} exceed max retry {}. Aborting". 
+ format(llm_iteration_count, max_llm_iteration)) + return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN, + raw_output=str(llm_response)) + + async def _get_llm_response(self, instruction: str): + prompt = self._compose_prompt(instruction) + logger.info("Send prompt to LLM:\n{}".format(prompt)) + response = await self.llm.async_completion(prompt) + if response.state == "error": + raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content))) + + logger.info("Got response from llm, raw response content: \n{}".format(response.content)) + return response + + def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]: + + for stop_word in STOP_WORD: + if stop_word in llm_output: + llm_output = llm_output.split(stop_word)[0].rstrip() + break + + # Check for Final Answer, if it is final, then just return + for indicator in FINAL_ANSWER_INDICATORS: + if indicator in llm_output: + # got final answer and remove the indicator + parts = llm_output.split(indicator) + # formatted_output = ''.join(parts[:-1]).strip() + formatted_output = ''.join(parts).strip() + formatted_output = replace_latex_format(formatted_output) + return AgentFinish(raw_output=llm_output, formatted_output=formatted_output) + + # Updated regex pattern for capturing the expected input format + ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$" + ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$" + + action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL) + + # Find action, context, and action input, build action response + if action_match: + context = action_match.group(1).strip() + action_tool_description = action_match.group(2).strip() + action_input = action_match.group(3).strip() + + # Format code 
+ # TODO: currently we only have one plugin which is sandbox, update to support multiple tools + format_code_block = self._format_code_block(action_input) + + prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN + formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block) + formatted_output = replace_latex_format(formatted_output) + + return AgentAction(tool=action_tool_description, + tool_input=format_code_block, + formatted_output=formatted_output, + raw_output=llm_output) + + # Not final answer and not action, raise exception + if not re.search(r"Action\s*:", llm_output, re.DOTALL): + raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`") + elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL): + raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`") + else: + raise LLMException(f"Unrecognized LLM output format: `{llm_output}`") + + def _format_code_block(self, tool_input): + stripped_tool_input = tool_input.strip() + + if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG): + if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'): + stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \ + '\n' + formatted_code = stripped_tool_input + elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith( + CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG): + formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n' + else: + formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n' + + return formatted_code.encode("utf-8").decode("utf-8") + + def _process_sandbox_output(self, output: str): + """Function to process the result containing STDERR.""" + if len(output) <= 1000: + return output + + logger.info("Output contains error, original message is over 1000, 
trim it for response. ori output: \n{}". + format(output)) + rows = output.split("\n") + # Get the first 500 characters, respecting line boundaries + top_segment = [] + length = 0 + for sub_p in rows: + if length + len(sub_p) > 500: + break + top_segment.append(sub_p) + length += len(sub_p) + + # Get the last 500 characters, respecting line boundaries + bottom_segment = [] + length = 0 + for sub_p in reversed(rows): + if length + len(sub_p) > 500: + break + bottom_segment.insert(0, sub_p) + length += len(sub_p) + + # Combine the segments with "......" in between + timed_output = "\n".join(top_segment + ["......"] + bottom_segment) + + return timed_output + + def _get_output_files(self, tool_response) -> list[MediaFile]: + output_files = [] + + if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput): + raw_output = tool_response.raw_output + + if raw_output.code == 0 and not raw_output.data.is_partial: + result_data = raw_output.data.result + + # TODO confirm if we still need output and format + if len(result_data.new_generated_files) > 0: + output_files.extend([MediaFile(tos_path=file.download_link) for file in + result_data.new_generated_files]) + + if len(result_data.code_output_result) > 0: + output_files.extend( + [MediaFile(tos_path=image.content) for image in result_data.code_output_result + if image.type == 'image']) + + return output_files + + def _replace_csv_path(self, input_string): + # Search for the pattern and replace it + pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)' + replacement = "pd.read_csv('/path/to/your/dataset')" + updated_string = re.sub(pattern, replacement, input_string) + return updated_string + + @staticmethod + def create_agent_response(formatted_output, output_files, raw_output): + return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output) + diff --git a/src/infiagent/conversation_sessions/__init__.py 
class CodeInterpreterSession:
    """One conversation with a code-interpreter agent and its sandbox.

    The session keeps the running message history and the uploaded input files, and
    forwards both to the agent on every ``chat`` call.
    """

    def __init__(
            self,
            session_id: Union[None, str] = None,
            model_name: Union[None, str] = "openai",
            config_path: Union[None, str] = None,
            agent: AsyncReactAgent = None,
            **kwargs):
        """Store the session state; prefer the ``create`` factory, which also builds the agent.

        :param session_id: Identifier of the session (``create`` passes the sandbox id).
        :param model_name: Model name, used for logging only.
        :param config_path: Path of the config the agent was built from.
        :param agent: The ready-to-use agent; its ``llm.model_name`` is read immediately.
        """
        self.session_id = session_id
        self.config_path = config_path
        self.input_files = []
        self.output_files = []
        self.messages = []
        self.agent = agent
        self.llm_model_name = self.agent.llm.model_name

        # Four values need four placeholders; the old three-placeholder template shifted
        # the values and silently dropped the session id.
        logger.info("Use model {} and llm {} in config {} for conversation {}"
                    .format(model_name, self.llm_model_name, self.config_path, self.session_id))

    @classmethod
    async def create(cls,
                     model_name: Union[None, str] = "openai",
                     config_path: Union[None, str] = None,
                     **kwargs: Dict[str, Any]):
        """Build the agent from config, bind it to a fresh sandbox id and return the session.

        :param model_name: Used to resolve ``config_path`` when none is given.
        :param config_path: Explicit config path; overrides the lookup by model name.
        :param kwargs: Forwarded to ``BaseAgent.async_from_config_path_and_kwargs``.
        """
        if config_path is None:
            config_path = get_model_config_path(model_name)
        logger.info(f"Use Config Path: {config_path}")

        sandbox_id = generate_random_string(12)

        # setup agent
        agent = await BaseAgent.async_from_config_path_and_kwargs(config_path, **kwargs)
        await agent.plugins_map["python_code_sandbox"].set_sandbox_id(sandbox_id)

        return cls(session_id=sandbox_id,
                   model_name=model_name,
                   config_path=config_path,
                   agent=agent)

    async def upload_to_sandbox(self, file: Union[str, FileStorage]):
        """Sync ``file`` into the sandbox and record it in the history and the input files."""
        dst_path = await self.agent.sync_to_sandbox(file)
        message = f'User uploaded the following files: {dst_path}\n'
        # Use the module logger (not the root ``logging`` module) like the rest of the class.
        logger.info(f"The file path {file} has been synced to sandbox with file path {dst_path}")
        self.messages.append(Message(RoleType.System, message))
        self.input_files.append(MediaFile(file_name=os.path.basename(dst_path), sandbox_path=dst_path))

    async def chat(self, user_messages, input_files=None):
        """Append ``user_messages`` to the history, run the agent and yield each response.

        Every agent response is also appended to the history as a system message.
        ``input_files`` is accepted for interface compatibility but is not read here.
        """
        start_time = time.time()

        self.messages.extend(user_messages)
        agent_request = AgentRequest(
            messages=self.messages,
            input_files=self.input_files,
            sandbox_id=self.session_id
        )
        logger.info(f"Agent request: {agent_request.__dict__}")

        async for agent_response in self.agent.async_run(agent_request):
            logger.info(f"Agent response:\n{agent_response.output_text}")
            self.messages.append(Message(RoleType.System, agent_response.output_text))
            yield agent_response

        exec_time = time.time()
        logger.info(
            f'Agent Execution Latency: {exec_time - start_time}'
        )

    def __enter__(self):
        # Return the session so that ``with session as s`` binds the session, not None.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Nothing to release; returning None lets exceptions propagate.
        return None
class DependencyException(Exception):
    """Root of the errors raised when an external dependency fails."""


class InputErrorException(Exception):
    """Root of the errors caused by invalid caller input."""


class InternalErrorException(Exception):
    """Error caused by an inconsistent internal state."""


class DatabaseException(DependencyException):
    """A database operation failed."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class SandboxException(DependencyException):
    """The code sandbox failed."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class LLMException(DependencyException):
    """The language model call failed or returned unusable output."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class ModelMaxIterationsException(DependencyException):
    """The model exhausted its allowed number of iterations."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class InvalidConfigException(InputErrorException):
    """A supplied configuration is invalid."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class SandBoxFileUploadException(SandboxException):
    """Uploading a file to the sandbox failed."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)


class PluginException(DependencyException):
    """A plugin failed."""

    def __init__(self, message, *args: object):
        super().__init__(message, *args)
class BaseLLM(ABC):
    """Common base of the LLM clients: holds the model name and its parameters.

    ``create``, ``completion`` and ``async_completion`` are stubs that return ``None``
    here; concrete clients override them.
    """

    def __init__(self, model_name: str, params: dict, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.__params = params
        self.__model_name = model_name

    @classmethod
    async def create(cls, config_data: dict):
        """Asynchronous factory hook; the base implementation builds nothing."""
        return None

    @property
    def params(self) -> dict:
        """The parameters given at construction time."""
        return self.__params

    @property
    def model_name(self) -> str:
        """The name of the model this client talks to."""
        return self.__model_name

    @model_name.setter
    def model_name(self, model_name):
        if model_name is not None:
            self.__model_name = model_name
            return
        raise InputErrorException("Invalid model_name {}".format(model_name))

    def completion(self, prompt) -> BaseCompletion:
        """Synchronous completion hook; the base implementation returns ``None``."""
        return None

    async def async_completion(self, prompt) -> BaseCompletion:
        """Asynchronous completion hook; the base implementation returns ``None``."""
        return None
class AzureOpenAIGPTClient(BaseLLM, ABC):
    """
    Wrapper class for the Azure-hosted OpenAI chat completion API.

    :param model_name: The name of the model (used as the Azure ``engine``).
    :type model_name: str
    :param params: The parameters for the model.
    :type params: AzureOpenAIParamModel
    """

    model_name: str
    # NOTE(review): this class attribute appears to shadow the ``BaseLLM.params`` property,
    # so ``self.params`` would always be this default instance and the ``params`` passed to
    # ``__init__`` would go unused -- confirm against the config loader before changing.
    params: AzureOpenAIParamModel = AzureOpenAIParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # These are process-wide settings of the ``openai`` module, not per-instance state.
        openai.api_key = os.environ.get("OPENAI_API_KEY", "")
        openai.api_type = "azure"
        openai.api_base = "https://search.bytedance.net/gpt/openapi/online/v2/crawl"
        openai.api_version = "2023-06-01-preview"

    @classmethod
    async def create(cls, config_data):
        """Build a client from ``config_data`` (expanded as keyword arguments)."""
        return AzureOpenAIGPTClient(**config_data)

    def get_model_name(self) -> str:
        return self.model_name

    def get_model_param(self) -> AzureOpenAIParamModel:
        return self.params

    @staticmethod
    def _usage(response, key: str) -> int:
        """Return ``response["usage"][key]``, or 0 when usage data is absent."""
        return response.get("usage", {}).get(key, 0)

    def _sampling_params(self) -> dict:
        """Return the sampling parameters shared by the prompt-completion calls."""
        return {
            "temperature": self.params.temperature,
            "max_tokens": self.params.max_tokens,
            "top_p": self.params.top_p,
            "frequency_penalty": self.params.frequency_penalty,
            "presence_penalty": self.params.presence_penalty,
        }

    def completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Completion method for the chat API, with retry/backoff.

        Only the last ``MAX_PROMPT_LENGTH`` characters of the prompt are sent.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        response = chatcompletion_with_backoff(
            engine=self.get_model_name(),
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            **self._sampling_params(),
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=self._usage(response, "prompt_tokens"),
                              completion_token=self._usage(response, "completion_tokens"))

    async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Asynchronous counterpart of ``completion``.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        response = await async_chatcompletion_with_backoff(
            engine=self.get_model_name(),
            messages=[
                {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
            ],
            timeout=1000,
            **self._sampling_params(),
            **kwargs
        )

        return BaseCompletion(state="success",
                              content=response.choices[0].message["content"],
                              prompt_token=self._usage(response, "prompt_tokens"),
                              completion_token=self._usage(response, "completion_tokens"))

    def chat_completion(self, message: List[dict]) -> ChatCompletion:
        """
        Chat completion method; failures are reported as a ``state="error"`` result.

        :param message: The message to use for completion.
        :type message: List[dict]
        :return: ChatCompletion object.
        :rtype: ChatCompletion
        """
        try:
            response = openai.ChatCompletion.create(
                engine=self.get_model_name(),
                messages=message,
                timeout=1000,
            )
            reply = response.choices[0].message
            return ChatCompletion(
                state="success",
                role=reply["role"],
                content=reply["content"],
                prompt_token=self._usage(response, "prompt_tokens"),
                completion_token=self._usage(response, "completion_tokens"),
            )
        except Exception as exception:
            logger.error(f"Chat completion failed: {exception}", exc_info=True)
            return ChatCompletion(state="error", content=str(exception))

    def stream_chat_completion(self, message: List[dict], **kwargs):
        """
        Stream output chat completion; yields one ``ChatCompletion`` per received chunk.

        :param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
        :type message: List[dict]
        :param kwargs: Additional keyword arguments forwarded to the API call.
        :type kwargs: dict
        :return: Generator of ChatCompletion objects.
        """
        # Streaming must be requested explicitly, otherwise the response is not iterable.
        # ``setdefault`` keeps callers that already pass ``stream=True`` working.
        kwargs.setdefault("stream", True)
        try:
            response = openai.ChatCompletion.create(
                engine=self.get_model_name(),
                messages=message,
                timeout=1000,
                **kwargs,
            )

            # The first chunk is read only for the role; its content is not yielded.
            role = next(response).choices[0].delta["role"]
            # TODO: Calculate prompt_token and completion_token for stream mode
            for resp in response:
                yield ChatCompletion(
                    state="success",
                    role=role,
                    content=resp.choices[0].delta.get("content", ""),
                    prompt_token=0,
                    completion_token=0,
                )
        except Exception as exception:
            logger.error(f"Stream chat completion failed: {exception}", exc_info=True)
            # ``return value`` inside a generator never reaches the consumer: yield the error.
            yield ChatCompletion(state="error", content=str(exception))

    def function_chat_completion(
            self,
            message: List[dict],
            function_map: Dict[str, Callable],
            function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion with function calling; ``message`` is extended in place.

        When the model requests a function, the function is invoked, its result is
        appended to ``message`` and a second completion is requested.

        :param message: The message to use for completion.
        :type message: List[dict]
        :param function_map: The function map to use for completion.
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema to use for completion.
        :type function_schema: List[Dict]
        :return: ChatCompletionWithHistory object (``state="error"`` on failure).
        :rtype: ChatCompletionWithHistory
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                engine=self.get_model_name(),
                messages=message,
                functions=function_schema,
                timeout=1000,
            )
            response_message = response.choices[0]["message"]
            function_call = response_message.get("function_call")

            if not function_call:
                message.append(dict(response_message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=response_message["role"],
                    content=response_message["content"],
                    prompt_token=self._usage(response, "prompt_tokens"),
                    completion_token=self._usage(response, "completion_tokens"),
                    message_scratchpad=message,
                )

            function_name = function_call["name"]
            function_to_call = function_map[function_name]
            function_args = json.loads(function_call["arguments"])
            function_response = function_to_call(**function_args)

            # Postprocess function response
            if isinstance(function_response, str):
                plugin_cost = 0
                plugin_token = 0
            elif isinstance(function_response, AgentOutput):
                plugin_cost = function_response.cost
                plugin_token = function_response.token_usage
                function_response = function_response.output
            else:
                raise TypeError(
                    "Invalid tool response type. Must be one of [AgentOutput, str]"
                )

            message.append(dict(response_message))
            message.append(
                {
                    "role": "function",
                    "name": function_name,
                    "content": function_response,
                }
            )
            # With ``api_type="azure"`` the deployment is selected with ``engine``, exactly
            # like the first request; ``model=`` is not accepted there.
            second_response = openai.ChatCompletion.create(
                engine=self.get_model_name(),
                messages=message,
            )
            final_message = second_response.choices[0].message
            message.append(dict(final_message))
            return ChatCompletionWithHistory(
                state="success",
                role=final_message["role"],
                content=final_message["content"],
                prompt_token=self._usage(response, "prompt_tokens")
                + self._usage(second_response, "prompt_tokens"),
                completion_token=self._usage(response, "completion_tokens")
                + self._usage(second_response, "completion_tokens"),
                message_scratchpad=message,
                plugin_cost=plugin_cost,
                plugin_token=plugin_token,
            )

        except Exception as exception:
            logger.error(f"Function chat completion failed: {exception}", exc_info=True)
            return ChatCompletionWithHistory(state="error", content=str(exception))

    def function_chat_stream_completion(
            self,
            message: List[dict],
            function_map: Dict[str, Callable],
            function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Streaming chat completion with function calling.

        This is a generator yielding ``(kind, ChatCompletionWithHistory)`` tuples, where
        ``kind`` is ``"function_call"`` or ``"content"`` as decided by the first chunk.
        Failures are logged and re-raised.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                # Azure deployments are addressed with ``engine`` (see the other methods).
                engine=self.get_model_name(),
                messages=message,
                functions=function_schema,
                stream=True,
                **self._sampling_params(),
            )
            first_chunk = next(response)
            first_delta = first_chunk.choices[0].delta
            role = first_delta["role"]
            _type = "function_call" if first_delta["content"] is None else "content"

            if _type == "function_call":
                name = first_delta["function_call"]["name"]
                # Opening fragment of the JSON object that the following chunks complete.
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )

            for resp in response:
                content = resp.choices[0].delta.get(_type, "")
                if isinstance(content, dict):
                    content = content["arguments"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )

        except Exception as e:
            logger.error(f"Failed to get response {str(e)}", exc_info=True)
            raise
_internal_coroutine() + + +class LlamaOpenAIClient(BaseLLM, ABC): + """ + Wrapper class for OpenAI GPT API collections. + + :param model_name: The name of the model to use. + :type model_name: str + :param params: The parameters for the model. + :type params: LlamaParamModel + """ + + model_name: str + params: LlamaParamModel = LlamaParamModel() + + def __init__(self, **data): + super().__init__(**data) + openai.api_key = "" + openai.api_base = "http://0.0.0.0:9729/v1" + + @classmethod + async def create(cls, config_data): + return LlamaOpenAIClient(**config_data) + + def get_model_name(self) -> str: + return self.model_name + + def get_model_param(self) -> LlamaParamModel: + return self.params + + def completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. + :rtype: BaseCompletion + """ + response = chatcompletion_with_backoff( + model=self.model_name, + messages=[ + {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} + ], + timeout=1000, + # temperature=self.params.temperature, + # max_tokens=self.params.max_tokens, + # top_p=self.params.top_p, + # frequency_penalty=self.params.frequency_penalty, + # presence_penalty=self.params.presence_penalty, + # stop=["<|im_end|>", "<|endoftext|>"], + **kwargs + ) + + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + + async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. 
+ :rtype: BaseCompletion + + """ + response = await async_chatcompletion_with_backoff( + model=self.model_name, + messages=[ + {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} + ], + timeout=1000, + #temperature=0.2, + #max_tokens=4096, + #top_p=0.9, + #frequency_penalty=self.params.frequency_penalty, + #presence_penalty=self.params.presence_penalty, + # stop=["<|im_end|>", "<|endoftext|>"], + **kwargs + ) + + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + + def chat_completion(self, message: List[dict]) -> ChatCompletion: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :return: ChatCompletion object. + :rtype: ChatCompletion + """ + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + timeout=1000, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + return ChatCompletion( + state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0), + ) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def stream_chat_completion(self, message: List[dict], **kwargs): + """ + Stream output chat completion for OpenAI GPT API. + + :param message: The message (scratchpad) to use for completion. Usually contains json of role and content. + :type message: List[dict] + :param kwargs: Additional keyword arguments. 
+ :type kwargs: dict + :return: ChatCompletion object. + :rtype: ChatCompletion + """ + try: + # response = openai.ChatCompletion.create( + # engine=self.get_model_name(), # GPT-4 + # messages=message, + # timeout=1000, + # **kwargs, + # ) + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True, + **kwargs + ) + role = next(response).choices[0].delta["role"] + messages = [] + ## TODO: Calculate prompt_token and for stream mode + for resp in response: + messages.append(resp.choices[0].delta.get("content", "")) + yield ChatCompletion( + state="success", + role=role, + content=messages[-1], + prompt_token=0, + completion_token=0, + ) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def function_chat_completion( + self, + message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict], + ) -> ChatCompletionWithHistory: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :param function_map: The function map to use for completion. + :type function_map: Dict[str, Callable] + :param function_schema: The function schema to use for completion. + :type function_schema: List[Dict] + :return: ChatCompletionWithHistory object. 
+ :rtype: ChatCompletionWithHistory + """ + assert len(function_schema) == len(function_map) + try: + # response = openai.ChatCompletion.create( + # engine=self.get_model_name(), # GPT-4 + # messages=message, + # functions=function_schema, + # timeout=1000, + # ) + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + response_message = response.choices[0]["message"] + + if response_message.get("function_call"): + function_name = response_message["function_call"]["name"] + fuction_to_call = function_map[function_name] + function_args = json.loads( + response_message["function_call"]["arguments"] + ) + function_response = fuction_to_call(**function_args) + + # Postprocess function response + if isinstance(function_response, str): + plugin_cost = 0 + plugin_token = 0 + elif isinstance(function_response, AgentOutput): + plugin_cost = function_response.cost + plugin_token = function_response.token_usage + function_response = function_response.output + else: + raise Exception( + "Invalid tool response type. 
Must be on of [AgentOutput, str]" + ) + + message.append(dict(response_message)) + message.append( + { + "role": "function", + "name": function_name, + "content": function_response, + } + ) + second_response = openai.ChatCompletion.create( + model=self.get_model_name(), + messages=message, + ) + message.append(dict(second_response.choices[0].message)) + return ChatCompletionWithHistory( + state="success", + role=second_response.choices[0].message["role"], + content=second_response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0) + + second_response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get( + "completion_tokens", 0 + ) + + second_response.get("usage", {}).get("completion_tokens", 0), + message_scratchpad=message, + plugin_cost=plugin_cost, + plugin_token=plugin_token, + ) + else: + message.append(dict(response_message)) + return ChatCompletionWithHistory( + state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get( + "completion_tokens", 0 + ), + message_scratchpad=message, + ) + + except Exception as exception: + print("Exception:", exception) + return ChatCompletionWithHistory(state="error", content=str(exception)) + + def function_chat_stream_completion( + self, + message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict], + ) -> ChatCompletionWithHistory: + assert len(function_schema) == len(function_map) + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.get_model_name(), + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True, + ) + 
tmp = next(response) + role = tmp.choices[0].delta["role"] + _type = ( + "function_call" + if tmp.choices[0].delta["content"] is None + else "content" + ) + if _type == "function_call": + name = tmp.choices[0].delta["function_call"]["name"] + yield _type, ChatCompletionWithHistory( + state="success", + role=role, + content="{" + f'"name":"{name}", "arguments":', + message_scratchpad=message, + ) + for resp in response: + # print(resp) + content = resp.choices[0].delta.get(_type, "") + if isinstance(content, dict): + content = content["arguments"] + yield _type, ChatCompletionWithHistory( + state="success", + role=role, + content=content, + message_scratchpad=message, + ) + + # result = ''.join(messages) + # if _type == "function_call": + # result = json.loads(result) + # function_name = result["name"] + # fuction_to_call = function_map[function_name] + # function_args = result["arguments"] + # function_response = fuction_to_call(**function_args) + # + # # Postprocess function response + # if isinstance(function_response, AgentOutput): + # function_response = function_response.output + # message.append({"role": "function", + # "name": function_name, + # "content": function_response}) + # second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema) + # message.append(dict(second_response.choices[0].message)) + + except Exception as e: + logger.error(f"Failed to get response {str(e)}", exc_info=True) + raise e diff --git a/src/infiagent/llm/client/openai.py b/src/infiagent/llm/client/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..851a9ee671204532ac36ebc84adce7eb9cc2f79b --- /dev/null +++ b/src/infiagent/llm/client/openai.py @@ -0,0 +1,306 @@ +import json +import os +from abc import ABC +from typing import Callable, List + +import openai + +from ..base_llm import BaseLLM +from ...schemas import * + + +class OpenAIGPTClient(BaseLLM, ABC): + """ + Wrapper class for OpenAI 
GPT API collections. + + :param model_name: The name of the model to use. + :type model_name: str + :param params: The parameters for the model. + :type params: OpenAIParamModel + """ + model_name: str + params: OpenAIParamModel = OpenAIParamModel() + + def __init__(self, **data): + super().__init__(**data) + openai.api_key = os.environ.get("OPENAI_API_KEY", "") + + @classmethod + async def create(cls, config_data): + return OpenAIGPTClient(**config_data) + + def get_model_name(self) -> str: + return self.model_name + + def get_model_param(self) -> OpenAIParamModel: + return self.params + + def completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. + :rtype: BaseCompletion + + """ + try: + #TODO any full parameters support + response = openai.ChatCompletion.create( + # n=self.params['n'], + engine=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=self.params['temperature'], + max_tokens=self.params['max_tokens'], + top_p=self.params['top_p'], + # frequency_penalty=self.params.frequency_penalty, + # presence_penalty=self.params.presence_penalty, + **kwargs + ) + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + except Exception as exception: + print("Exception:", exception) + return BaseCompletion(state="error", content=exception) + + async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Async Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. 
+ :rtype: BaseCompletion + + """ + try: + response = await openai.ChatCompletion.acreate( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=self.params['temperature'], + max_tokens=self.params['max_tokens'], + top_p=self.params['top_p'], + # frequency_penalty=self.params.frequency_penalty, + # presence_penalty=self.params.presence_penalty, + **kwargs + ) + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + except Exception as exception: + print("Exception:", exception) + return BaseCompletion(state="error", content=exception) + + + def chat_completion(self, message: List[dict]) -> ChatCompletion: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :return: ChatCompletion object. + :rtype: ChatCompletion + """ + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + return ChatCompletion(state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def stream_chat_completion(self, message: List[dict], **kwargs): + """ + Stream output chat completion for OpenAI GPT API. + + :param message: The message (scratchpad) to use for completion. Usually contains json of role and content. 
+ :type message: List[dict] + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: ChatCompletion object. + :rtype: ChatCompletion + """ + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True, + **kwargs + ) + role = next(response).choices[0].delta["role"] + messages = [] + ## TODO: Calculate prompt_token and for stream mode + for resp in response: + messages.append(resp.choices[0].delta.get("content", "")) + yield ChatCompletion(state="success", + role=role, + content=messages[-1], + prompt_token=0, + completion_token=0) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def function_chat_completion(self, message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict]) -> ChatCompletionWithHistory: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :param function_map: The function map to use for completion. + :type function_map: Dict[str, Callable] + :param function_schema: The function schema to use for completion. + :type function_schema: List[Dict] + :return: ChatCompletionWithHistory object. 
+ :rtype: ChatCompletionWithHistory + """ + assert len(function_schema) == len(function_map) + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + response_message = response.choices[0]["message"] + + if response_message.get("function_call"): + function_name = response_message["function_call"]["name"] + fuction_to_call = function_map[function_name] + function_args = json.loads(response_message["function_call"]["arguments"]) + function_response = fuction_to_call(**function_args) + + # Postprocess function response + if isinstance(function_response, str): + plugin_cost = 0 + plugin_token = 0 + elif isinstance(function_response, AgentOutput): + plugin_cost = function_response.cost + plugin_token = function_response.token_usage + function_response = function_response.output + else: + raise Exception("Invalid tool response type. 
Must be on of [AgentOutput, str]") + + message.append(dict(response_message)) + message.append({"role": "function", + "name": function_name, + "content": function_response}) + second_response = openai.ChatCompletion.create( + model=self.model_name, + messages=message, + ) + message.append(dict(second_response.choices[0].message)) + return ChatCompletionWithHistory(state="success", + role=second_response.choices[0].message["role"], + content=second_response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0) + + second_response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0) + + second_response.get("usage", {}).get("completion_tokens", 0), + message_scratchpad=message, + plugin_cost=plugin_cost, + plugin_token=plugin_token, + ) + else: + message.append(dict(response_message)) + return ChatCompletionWithHistory(state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0), + message_scratchpad=message) + + except Exception as exception: + print("Exception:", exception) + return ChatCompletionWithHistory(state="error", content=str(exception)) + + def function_chat_stream_completion(self, message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict]) -> ChatCompletionWithHistory: + assert len(function_schema) == len(function_map) + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True + ) + tmp = next(response) + role = tmp.choices[0].delta["role"] + 
_type = "function_call" if tmp.choices[0].delta["content"] is None else "content" + if _type == "function_call": + name = tmp.choices[0].delta['function_call']['name'] + yield _type, ChatCompletionWithHistory(state="success", role=role, + content="{" + f'"name":"{name}", "arguments":', + message_scratchpad=message) + for resp in response: + # print(resp) + content = resp.choices[0].delta.get(_type, "") + if isinstance(content, dict): + content = content['arguments'] + yield _type, ChatCompletionWithHistory(state="success", + role=role, + content=content, + message_scratchpad=message) + + # result = ''.join(messages) + # if _type == "function_call": + # result = json.loads(result) + # function_name = result["name"] + # fuction_to_call = function_map[function_name] + # function_args = result["arguments"] + # function_response = fuction_to_call(**function_args) + # + # # Postprocess function response + # if isinstance(function_response, AgentOutput): + # function_response = function_response.output + # message.append({"role": "function", + # "name": function_name, + # "content": function_response}) + # second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema) + # message.append(dict(second_response.choices[0].message)) + + + except Exception as exception: + raise exception + print("Exception:", exception) + return ChatCompletion(state="error", content=str(exception)) diff --git a/src/infiagent/llm/client/opt.py b/src/infiagent/llm/client/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..bb21fb5f2cbe58f844afdd3beae62eeeddf5e8e8 --- /dev/null +++ b/src/infiagent/llm/client/opt.py @@ -0,0 +1,373 @@ +import json +import logging +import os +from abc import ABC +from typing import Callable, List + +import openai +from tenacity import ( # for exponential backoff + before_sleep_log, + retry, + stop_after_attempt, + wait_random_exponential, +) + +from ..base_llm import BaseLLM 
+from ...schemas import * + +logger = logging.getLogger(__name__) + +MAX_PROMPT_LENGTH = 7000 + + +@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING)) +def chatcompletion_with_backoff(**kwargs): + return openai.ChatCompletion.create(**kwargs) + + +@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING)) +async def async_chatcompletion_with_backoff(**kwargs): + async def _internal_coroutine(): + return await openai.ChatCompletion.acreate(**kwargs) + + return await _internal_coroutine() + + +class OptOpenAIClient(BaseLLM, ABC): + """ + Wrapper class for OpenAI GPT API collections. + + :param model_name: The name of the model to use. + :type model_name: str + :param params: The parameters for the model. + :type params: OptParamModel + """ + + model_name: str + params: OptParamModel = OptParamModel() + + def __init__(self, **data): + super().__init__(**data) + openai.api_key = "EMPTY" + openai.api_base = "http://localhost:8000/v1" + + @classmethod + async def create(cls, config_data): + return OptOpenAIClient(**config_data) + + def get_model_name(self) -> str: + return self.model_name + + def get_model_param(self) -> OptParamModel: + return self.params + + def completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. 
+ :rtype: BaseCompletion + + """ + + response = chatcompletion_with_backoff( + model=self.model_name, + # engine=self.get_model_name(), # GPT-4 + messages=[ + {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} + ], + timeout=1000, + **kwargs + ) + + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + + async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion: + """ + Completion method for OpenAI GPT API. + + :param prompt: The prompt to use for completion. + :type prompt: str + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: BaseCompletion object. + :rtype: BaseCompletion + + """ + response = await async_chatcompletion_with_backoff( + # engine=self.get_model_name(), # GPT-4 + model=self.model_name, + messages=[ + {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} + ], + timeout=1000, + **kwargs + ) + + return BaseCompletion(state="success", + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0)) + + def chat_completion(self, message: List[dict]) -> ChatCompletion: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :return: ChatCompletion object. 
+ :rtype: ChatCompletion + """ + try: + # response = openai.ChatCompletion.create( + # engine=self.get_model_name(), # GPT-4 + # messages=message, + # timeout=1000, + # ) + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + return ChatCompletion( + state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get("completion_tokens", 0), + ) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def stream_chat_completion(self, message: List[dict], **kwargs): + """ + Stream output chat completion for OpenAI GPT API. + + :param message: The message (scratchpad) to use for completion. Usually contains json of role and content. + :type message: List[dict] + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :return: ChatCompletion object. 
+ :rtype: ChatCompletion + """ + try: + # response = openai.ChatCompletion.create( + # engine=self.get_model_name(), # GPT-4 + # messages=message, + # timeout=1000, + # **kwargs, + # ) + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True, + **kwargs + ) + role = next(response).choices[0].delta["role"] + messages = [] + ## TODO: Calculate prompt_token and for stream mode + for resp in response: + messages.append(resp.choices[0].delta.get("content", "")) + yield ChatCompletion( + state="success", + role=role, + content=messages[-1], + prompt_token=0, + completion_token=0, + ) + except Exception as exception: + print("Exception:", exception) + return ChatCompletion(state="error", content=exception) + + def function_chat_completion( + self, + message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict], + ) -> ChatCompletionWithHistory: + """ + Chat completion method for OpenAI GPT API. + + :param message: The message to use for completion. + :type message: List[dict] + :param function_map: The function map to use for completion. + :type function_map: Dict[str, Callable] + :param function_schema: The function schema to use for completion. + :type function_schema: List[Dict] + :return: ChatCompletionWithHistory object. 
+ :rtype: ChatCompletionWithHistory + """ + assert len(function_schema) == len(function_map) + try: + # response = openai.ChatCompletion.create( + # engine=self.get_model_name(), # GPT-4 + # messages=message, + # functions=function_schema, + # timeout=1000, + # ) + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.model_name, + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + ) + response_message = response.choices[0]["message"] + + if response_message.get("function_call"): + function_name = response_message["function_call"]["name"] + fuction_to_call = function_map[function_name] + function_args = json.loads( + response_message["function_call"]["arguments"] + ) + function_response = fuction_to_call(**function_args) + + # Postprocess function response + if isinstance(function_response, str): + plugin_cost = 0 + plugin_token = 0 + elif isinstance(function_response, AgentOutput): + plugin_cost = function_response.cost + plugin_token = function_response.token_usage + function_response = function_response.output + else: + raise Exception( + "Invalid tool response type. 
Must be on of [AgentOutput, str]" + ) + + message.append(dict(response_message)) + message.append( + { + "role": "function", + "name": function_name, + "content": function_response, + } + ) + second_response = openai.ChatCompletion.create( + model=self.get_model_name(), + messages=message, + ) + message.append(dict(second_response.choices[0].message)) + return ChatCompletionWithHistory( + state="success", + role=second_response.choices[0].message["role"], + content=second_response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0) + + second_response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get( + "completion_tokens", 0 + ) + + second_response.get("usage", {}).get("completion_tokens", 0), + message_scratchpad=message, + plugin_cost=plugin_cost, + plugin_token=plugin_token, + ) + else: + message.append(dict(response_message)) + return ChatCompletionWithHistory( + state="success", + role=response.choices[0].message["role"], + content=response.choices[0].message["content"], + prompt_token=response.get("usage", {}).get("prompt_tokens", 0), + completion_token=response.get("usage", {}).get( + "completion_tokens", 0 + ), + message_scratchpad=message, + ) + + except Exception as exception: + print("Exception:", exception) + return ChatCompletionWithHistory(state="error", content=str(exception)) + + def function_chat_stream_completion( + self, + message: List[dict], + function_map: Dict[str, Callable], + function_schema: List[Dict], + ) -> ChatCompletionWithHistory: + assert len(function_schema) == len(function_map) + try: + response = openai.ChatCompletion.create( + n=self.params.n, + model=self.get_model_name(), + messages=message, + functions=function_schema, + temperature=self.params.temperature, + max_tokens=self.params.max_tokens, + top_p=self.params.top_p, + frequency_penalty=self.params.frequency_penalty, + presence_penalty=self.params.presence_penalty, + stream=True, + ) + 
tmp = next(response) + role = tmp.choices[0].delta["role"] + _type = ( + "function_call" + if tmp.choices[0].delta["content"] is None + else "content" + ) + if _type == "function_call": + name = tmp.choices[0].delta["function_call"]["name"] + yield _type, ChatCompletionWithHistory( + state="success", + role=role, + content="{" + f'"name":"{name}", "arguments":', + message_scratchpad=message, + ) + for resp in response: + # print(resp) + content = resp.choices[0].delta.get(_type, "") + if isinstance(content, dict): + content = content["arguments"] + yield _type, ChatCompletionWithHistory( + state="success", + role=role, + content=content, + message_scratchpad=message, + ) + + # result = ''.join(messages) + # if _type == "function_call": + # result = json.loads(result) + # function_name = result["name"] + # fuction_to_call = function_map[function_name] + # function_args = result["arguments"] + # function_response = fuction_to_call(**function_args) + # + # # Postprocess function response + # if isinstance(function_response, AgentOutput): + # function_response = function_response.output + # message.append({"role": "function", + # "name": function_name, + # "content": function_response}) + # second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema) + # message.append(dict(second_response.choices[0].message)) + + except Exception as e: + logger.error(f"Failed to get response {str(e)}", exc_info=True) + raise e diff --git a/src/infiagent/prompt/__init__.py b/src/infiagent/prompt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7f12a7c40c6f088db434f7dee3d2808bf443bd --- /dev/null +++ b/src/infiagent/prompt/__init__.py @@ -0,0 +1,3 @@ +from .prompt_template import * +from .simple_react_prompt import SimpleReactPrompt +from .zero_shot_react_prompt import ZeroShotReactPrompt diff --git a/src/infiagent/prompt/prompt_template.py b/src/infiagent/prompt/prompt_template.py new 
file mode 100644 index 0000000000000000000000000000000000000000..f64b0444fe0f9f9e1ad0265f6635eb713e366d7d --- /dev/null +++ b/src/infiagent/prompt/prompt_template.py @@ -0,0 +1,83 @@ +"""Prompt schema definition.""" +from abc import ABC, abstractmethod +from string import Formatter +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Extra, root_validator + +from ..exceptions.exceptions import InputErrorException +from ..schemas import AgentAction, AgentObservation, BaseAgentResponse + +OBSERVATION_KEY = "Observation" +THOUGHT_KEY = "Thought" +FINAL_ANSWER_KEY = "FinalAnswer" + +DEFAULT_OBSERVATION = "Observation:" +DEFAULT_THOUGHT = "Thought:" +DEFAULT_FINAL_ANSWER = "Final Answer:" + + +class PromptTemplate(BaseModel, ABC): + _input_variables: List[str] + _template: str + _keywords: Dict[str, str] + _name: str + _validate_template: bool + _skip_on_failure: bool + + class Config: + extra = Extra.forbid + + @property + def input_variables(self) -> List[str]: + return self._input_variables + + @property + def template(self) -> str: + return self._template + + @property + def keywords(self) -> Dict[str, str]: + return self._keywords + + @property + def name(self) -> str: + return self._name + + def format(self, **kwargs): + if not set(self._input_variables).issubset(kwargs.keys()): + missing_keys = set(self._input_variables) - kwargs.keys() + raise InputErrorException(f"Missing keys in prompt template: {', '.join(missing_keys)}") + + filtered_kwargs = {key: kwargs[key] for key in self._input_variables if key in kwargs} + + return self._template.format(**filtered_kwargs) + + def construct_scratchpad(self, intermediate_steps: List[BaseAgentResponse]) -> str: + """Construct the scratchpad that lets the agent continue its thought process.""" + thoughts = "" + + for agent_response in intermediate_steps: + if isinstance(agent_response, AgentAction): + # for agent action, use thought + thoughts += agent_response.raw_output + elif 
isinstance(agent_response, AgentObservation): + # for agent observation use observation + thoughts += f"\n{self.keywords.get(OBSERVATION_KEY, DEFAULT_OBSERVATION)}\n" \ + f"{agent_response.formatted_output}\n\n" \ + f"{self.keywords.get(THOUGHT_KEY, DEFAULT_THOUGHT)}\n" + + return thoughts + + @classmethod + @root_validator(skip_on_failure=True) + def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + if values["validate_template"]: + try: + dummy_input = {var: "" for var in values["input_variables"]} + Formatter().format(values["template"], **dummy_input) + except KeyError as e: + raise InputErrorException("Invalid prompt schema; check for mismatched or missing input parameters. ")\ + from e + return values diff --git a/src/infiagent/prompt/simple_react_prompt.py b/src/infiagent/prompt/simple_react_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..a36538e8644f33ab2bd949abc33b54a4133ed38e --- /dev/null +++ b/src/infiagent/prompt/simple_react_prompt.py @@ -0,0 +1,17 @@ +from ..prompt import FINAL_ANSWER_KEY, OBSERVATION_KEY, THOUGHT_KEY, PromptTemplate + + +class SimpleReactPrompt(PromptTemplate): + _input_variables = ["instruction", "agent_scratchpad"] + _template = "{instruction} \n{agent_scratchpad}" + _keywords = { + OBSERVATION_KEY: "[EOS]Observation:", + THOUGHT_KEY: "[SEP]", + FINAL_ANSWER_KEY: "[END]" + } + _name = 'SimpleReactPrompt' + _validate_template = True + _skip_on_failure = True + + def __init__(self, **data): + super().__init__(**data) diff --git a/src/infiagent/prompt/zero_shot_react_prompt.py b/src/infiagent/prompt/zero_shot_react_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6b3c5af8caf74299bc587dd4c1d2d14f7ba4f6 --- /dev/null +++ b/src/infiagent/prompt/zero_shot_react_prompt.py @@ -0,0 +1,36 @@ +from ..prompt import PromptTemplate, OBSERVATION_KEY, THOUGHT_KEY, FINAL_ANSWER_KEY, DEFAULT_OBSERVATION, \ + DEFAULT_THOUGHT, 
class ZeroShotReactPrompt(PromptTemplate):
    """Zero-shot ReAct prompt: tool list, Thought/Action/Observation format, then the question.

    The template expects four inputs: ``instruction``, ``agent_scratchpad``,
    ``tool_names`` and ``tool_description``.
    """
    # Placeholders that must be supplied when the prompt is formatted.
    _input_variables = ["instruction", "agent_scratchpad", "tool_names", "tool_description"]
    # NOTE(review): the first two literals are concatenated without a separator,
    # so the rendered text reads "...as best you can.You have access..." --
    # confirm whether a space/newline was intended before changing prompt text.
    _template = (
        "Answer the following questions as best you can."
        "You have access to the following tools:\n"
        "{tool_description}.\n"
        "Use the following format:\n\n"
        "Question: the input question you must answer\n"
        "Thought: you should always think about what to do\n\n"
        "Action: the action to take, should be one of [{tool_names}]\n\n"
        "Action Input:\n```python\n[the input to the action]\n```\n"
        "Observation: the result of the action\n\n"
        "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
        "Thought: I now know the final answer\n"
        "Final Answer: the final answer to the original input question\n"
        "If you have any files outputted write them to \"./\"\n"
        "Do not use things like plot.show() as it will not work instead write them out \"./\"\n"
        "Begin!\n\n"
        "Question: {instruction}\nThought:\n"
        "{agent_scratchpad}\n"
    )
    # Uses the module defaults ("Observation:", "Thought:", "Final Answer:") as markers.
    _keywords = {
        OBSERVATION_KEY: DEFAULT_OBSERVATION,
        THOUGHT_KEY: DEFAULT_THOUGHT,
        FINAL_ANSWER_KEY: DEFAULT_FINAL_ANSWER
    }
    _name = 'ZeroShotReactPrompt'
    _validate_template = True
    _skip_on_failure = True

    def __init__(self, **data):
        # Pure pass-through to the base-class constructor.
        super().__init__(**data)
class AgentType(Enum):
    """Closed set of agent kinds known to the configuration layer."""
    openai = "openai"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"
    openai_memory = "openai_memory"

    @staticmethod
    def get_agent_class(_type: AgentType):
        """
        Resolve an agent type to its implementing class.

        Only ``AgentType.react`` is mapped; every other value is rejected.

        :param _type: agent type
        :return: agent class
        :raises ValueError: if the type has no implementing class
        """
        if _type != AgentType.react:
            raise ValueError(f"Unknown agent type: {_type}")

        # Imported lazily, at call time rather than at module import.
        from ..agent.react import ReactAgent
        return ReactAgent
class RoleType(Enum):
    """Author of a chat message."""
    User = 0
    System = 1
    Agent = 2

    @classmethod
    def _missing_(cls, name):
        # If the input is a string, perform case-insensitive matching
        if isinstance(name, str):
            for member in cls:
                if member.name.lower() == name.lower():
                    return member
        return super()._missing_(name)


@dataclass
class Message(abc.ABC):
    """A single chat message with its author role.

    ``raw_content`` defaults to an empty string.
    """
    role: RoleType
    content: str
    raw_content: str = ""

    @staticmethod
    def parse_from_dict(data):
        """Build a ``Message`` from a plain dict without modifying ``data``.

        ``data['role']`` may be a ``RoleType``, its integer value, or a
        case-insensitive member name. A missing ``raw_content`` (legacy data)
        becomes ``""``.

        :raises KeyError: if ``role`` is absent.
        :raises ValueError: if ``role`` does not match any ``RoleType``.
        """
        # Work on a copy: the caller's dict must keep its original values
        # (previously 'role' was overwritten with an enum member in place).
        payload = dict(data)
        payload['role'] = RoleType(payload['role'])
        payload.setdefault('raw_content', "")
        return Message(**payload)

    def to_dict(self):
        """Return a plain-dict form with ``role`` reduced to its enum value."""
        role_value = self.role.value if isinstance(self.role, RoleType) else self.role
        return {
            "role": role_value,
            "content": self.content,
            "raw_content": self.raw_content
        }
# Main Input Model
class ChatCompleteRequest(BaseModel):
    """Request body for the /complete chat API.

    Only ``chat_id`` and ``user`` are required; every other field has a default.
    """
    chat_id: str  # unique chat id for given chat
    # May carry a "tos_key" entry naming an input file (read by chat_request_to_message_conf).
    code_interpreter: Optional[dict] = {}
    messages: List[dict] = []  # chat message; each item's "content" key is read downstream
    model: str = "AZURE_OPEN_AI"  # model name map to LLM conf
    user: str
    max_tokens: Optional[int] = None
    message_conf: Optional[dict] = {}
    n: Optional[int] = None
    plugins: Optional[List[str]] = None
    seed_conf: Optional[dict] = {}
    # stream / temperature / top_p left as None are replaced by the DEFAULT_*
    # constants when a MessageConf is built from the request.
    stream: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    webgpt: Optional[Dict[str, Any]] = None
    webgpt_network: Optional[bool] = None
def chat_request_to_message_conf(chat_request: ChatCompleteRequest) -> MessageConf:
    """Derive the per-message configuration echoed back in a response.

    ``top_p``, ``temperature`` and ``stream`` fall back to the module defaults
    when the request leaves them as ``None``. ``top_k`` is forwarded as given
    (it was previously dropped, so the response always reported ``None``).
    When the request's ``code_interpreter`` contains a ``"tos_key"``, only the
    file-name part of that key is kept.
    """
    input_files = {}

    if chat_request.code_interpreter and "tos_key" in chat_request.code_interpreter:
        input_file = chat_request.code_interpreter["tos_key"]
        # Only the file name is exposed; the storage path is not needed here.
        file_name, _ = get_file_name_and_path(input_file)
        input_files = {"tos_key": file_name}

    return MessageConf(
        top_p=chat_request.top_p if chat_request.top_p is not None else DEFAULT_TOP_P,
        temperature=chat_request.temperature if chat_request.temperature is not None else DEFAULT_TEMPERATURE,
        top_k=chat_request.top_k,
        code_interpreter=input_files,
        time_cost=0,
        gpt_engine_conf={},
        stream=chat_request.stream if chat_request.stream is not None else DEFAULT_STREAM
    )
def update_chat_response_with_message(chat_response: ChatCompleteResponse,
                                      message: Message,
                                      status: Union[str, None] = None) -> ChatCompleteResponse:
    """Replace the response's choices with a single choice carrying ``message``.

    Identity fields (sid, parent/children ids, creator, updater, ctime) and
    message configuration are inherited from the last existing delta when
    there is one. ``chat_response`` is modified in place and returned.

    :param status: delta status; ``None`` means ``FINISH_STATUS``.
    """
    # Get the last Delta (if exists)
    last_delta = chat_response.choices[-1].delta if chat_response.choices else None
    now = current_utc_time_as_str()

    if last_delta:
        previous_conf = last_delta.message_conf
        inherited = dict(
            sid=last_delta.sid,
            parent_id=last_delta.parent_id,
            children_ids=last_delta.children_ids,
            creator=last_delta.creator,
            updater=last_delta.updater,
            ctime=last_delta.ctime,
        )
        # Falsy top_p / temperature / code_interpreter fall back to defaults,
        # exactly as before; stream only falls back when it is None.
        message_conf = MessageConf(
            top_p=previous_conf.top_p or DEFAULT_TOP_P,
            temperature=previous_conf.temperature or DEFAULT_TEMPERATURE,
            code_interpreter=previous_conf.code_interpreter or {},
            time_cost=0,
            gpt_engine_conf={},
            stream=previous_conf.stream if previous_conf.stream is not None else False,
        )
    else:
        # No previous delta (e.g. a request without messages). creator/updater
        # are declared as `str` on Delta, so use "" rather than None, which
        # the model would reject.
        inherited = dict(
            sid="",
            parent_id="",
            children_ids=None,
            creator="",
            updater="",
            ctime=now,
        )
        message_conf = MessageConf(
            top_p=DEFAULT_TOP_P,
            temperature=DEFAULT_TEMPERATURE,
            code_interpreter={},
            time_cost=0,
            gpt_engine_conf={},
            stream=False,
        )

    updated_delta = Delta(
        role=ASSISTANT,  # map with front end
        content=message.content,
        status=status if status is not None else FINISH_STATUS,
        end_turn=False,
        err_msg="",
        utime=now,
        message_conf=message_conf,
        **inherited
    )

    updated_choice = Choice(
        index=0,  # Since it's the only choice in the list
        delta=updated_delta,
        finish_reason="stop"
    )

    # Update the ChatCompleteResponse to contain only the new Choice
    chat_response.choices = [updated_choice]
    return chat_response
class BaseParamModel(BaseModel):
    """Parameter model compared by field values (``.dict()`` equality)."""

    def __eq__(self, other):
        """Return True when both operands produce equal ``.dict()`` output.

        Operands without a callable ``dict`` attribute yield ``NotImplemented``
        so that ``model == 3`` evaluates to False instead of raising
        ``AttributeError``.
        """
        other_dict = getattr(other, "dict", None)
        if not callable(other_dict):
            return NotImplemented
        return self.dict() == other_dict()
class LlamaParamModel(BaseModel):
    """
    Default generation parameters for the Llama model family.

    NOTE(review): the previous docstring said "AzureOpenAI API parameters",
    apparently copied from the sibling model; the purpose stated here is
    inferred from the class name -- confirm.
    """
    max_tokens: int = 4096  # larger default than the OpenAI/Azure models (2048)
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1
    stop: list = []
+ """ + success = "success" + failed = "failed" + timeout = "timeout" + +class CodeOutput(BaseModel): + type: str + content: str + +class ReturnedFile(BaseModel): + download_link: str + name: str + path: str + +class CodeRunResult(BaseModel): + code_output_result: List[CodeOutput] + deleted_files: List[ReturnedFile] + new_generated_files: List[ReturnedFile] + +class CodeRunData(BaseModel): + is_partial: bool + result: CodeRunResult + + +class RunCodeOutput(BaseModel): + code: int + message: str + data: Optional[CodeRunData] + +class CreateSessionOutput(BaseModel): + code: int + message: str + + +class ErrorResponse(BaseModel): + code: int + message: str + data: Optional[Any] + + +class UploadOutput(BaseModel): + code: int + message: Optional[str] + data: Optional[str] + + +# Model for successful response (assuming it's a text file for this example) +class DownloadSuccessOutput(BaseModel): + file_name: str # this is not part of server response. We must fill this field in client. + content: str + + +class HeartbeatOutput(BaseModel): + code: Optional[int] + message: Optional[str] + + +class RefreshSandboxOutput(BaseModel): + code: Optional[int] + message: Optional[str] + + diff --git a/src/infiagent/services/__init__.py b/src/infiagent/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/infiagent/services/chat_complete_service.py b/src/infiagent/services/chat_complete_service.py new file mode 100644 index 0000000000000000000000000000000000000000..dc713a0be1823ddeed02c5126980f83e5abcc907 --- /dev/null +++ b/src/infiagent/services/chat_complete_service.py @@ -0,0 +1,196 @@ +import time +from io import BytesIO +from typing import Any, Dict, List, Union + +from fastapi import UploadFile +from starlette.datastructures import UploadFile as StarletteUploadFile +from werkzeug.datastructures import FileStorage + +from ..conversation_sessions import CodeInterpreterSession +from 
async def predict(
        prompt: str,
        model_name: str,
        config_path: str,
        uploaded_files: Any,
        **kwargs: Dict[str, Any]):
    """Run one prompt through a fresh code-interpreter session and return the text.

    Creates a session, uploads the given files to its sandbox, streams the
    chat and concatenates every response's ``output_text``.

    :param prompt: user instruction sent as a single user message.
    :param model_name: forwarded to ``CodeInterpreterSession.create``.
    :param config_path: forwarded to ``CodeInterpreterSession.create``.
    :param uploaded_files: passed to ``upload_files``; the result may be a
        path string or a list of path strings / upload objects.
    :return: concatenated output text, or ``None`` if the chat produced no response.
    :raises InputErrorException: for an unsupported upload object type.
    :raises Exception: with a user-facing message when the chat fails; the
        original error is attached as ``__cause__``.

    NOTE(review): a second ``predict`` is defined later in this module and
    shadows this one at import time -- confirm which copy is intended to survive.
    """
    start_time = time.time()

    # create new session
    session = await CodeInterpreterSession.create(
        model_name=model_name,
        config_path=config_path,
        **kwargs
    )

    try:
        files = upload_files(uploaded_files, session.session_id)
        logger.info(f"Session Creation Latency: {time.time() - start_time}")

        # upload file
        if isinstance(files, str):
            logger.info(f"Upload {files} as file path")
            await session.upload_to_sandbox(files)
        # upload list of file
        elif isinstance(files, list):
            for file in files:
                if isinstance(file, str):
                    await session.upload_to_sandbox(file)
                elif isinstance(file, (UploadFile, StarletteUploadFile)):
                    # Re-wrap the upload as a werkzeug FileStorage backed by memory.
                    file_storage = FileStorage(
                        stream=BytesIO(file.file.read()),
                        filename=file.filename,
                        content_type=file.content_type
                    )
                    await session.upload_to_sandbox(file_storage)
                else:
                    raise InputErrorException(
                        "The file type {} not supported, can't be uploaded".format(type(file)))

        # chat
        try:
            logger.info(f"Instruction message: {prompt}")
            chunks = []
            user_messages = [Message(RoleType.User, prompt)]
            async for response in session.chat(user_messages):
                logger.info(f'Session Chat Response: {response}')
                chunks.append(response.output_text)

            # Preserve the "no response at all" result of None.
            content = "".join(chunks) if chunks else None

            session.messages.append(Message(RoleType.Agent, content))
            logger.info(f"Total Latency: {time.time() - start_time}")

            return content
        except Exception as e:
            # Messages are keyed by exact exception type; anything else
            # (including subclasses) gets the generic text with the error string.
            exception_messages = {
                ModelMaxIterationsException: "Sorry. The agent didn't find the correct answer after multiple trials, "
                                             "Please try another question.",
                DependencyException: "Agent failed to process message due to dependency issue. You can try it later. "
                                     "If it still happens, please contact oncall.",
                InputErrorException: "Agent failed to process message due to value issue. If you believe all input are "
                                     "correct, please contact oncall.",
                InternalErrorException: "Agent failed to process message due to internal error, please contact oncall.",
                Exception: "Agent failed to process message due to unknown error, please contact oncall."
            }
            err_msg = exception_messages.get(type(e), f"Unknown error occurred: {str(e)}")
            logger.error(err_msg, exc_info=True)

            # Chain the cause so the original traceback is not lost.
            raise Exception(err_msg) from e
    finally:
        # Release the sandbox on every exit path, not only on success,
        # so failed uploads or chats do not leak kernels.
        AsyncPythonSandBoxTool.kill_kernels(session.session_id)
        logger.info(f"Release python sandbox {session.session_id}")
session.session_id) + logger.info(f"Session Creation Latency: {time.time() - start_time}") + + # upload file + if isinstance(files, str): + logger.info(f"Upload {files} as file path") + await session.upload_to_sandbox(files) + # upload list of file + elif isinstance(files, list): + for file in files: + if isinstance(file, str): + await session.upload_to_sandbox(file) + elif isinstance(file, UploadFile) or isinstance(file, StarletteUploadFile): + file_content = file.file.read() # get file content + file_like_object = BytesIO(file_content) + file_storage = FileStorage( + stream=file_like_object, + filename=file.filename, + content_type=file.content_type + ) + await session.upload_to_sandbox(file_storage) + else: + raise InputErrorException("The file type {} not supported, can't be uploaded".format(type(file))) + + # chat + try: + logger.info(f"Instruction message: {prompt}") + content = None + output_files = [] + user_messages = [Message(RoleType.User, prompt)] + + async for response in session.chat(user_messages): + logger.info(f'Session Chat Response: {response}') + if content is None: + content = response.output_text + else: + content += response.output_text + + output_files.extend([output_file.__dict__() for output_file in response.output_files]) + + session.messages.append(Message(RoleType.Agent, content)) + + logger.info(f"Total Latency: {time.time() - start_time}") + + return content + except (ModelMaxIterationsException, DependencyException, InputErrorException, InternalErrorException, Exception) \ + as e: + exception_messages = { + ModelMaxIterationsException: "Sorry. The agent didn't find the correct answer after multiple trials, " + "Please try another question.", + DependencyException: "Agent failed to process message due to dependency issue. You can try it later. " + "If it still happens, please contact oncall.", + InputErrorException: "Agent failed to process message due to value issue. 
async def chat_event_generator(chat_request: ChatCompleteRequest):
    """
    Init a chat session and start pushing response back.
    This function is for SSE apis

    Yields response dicts produced by ``process_chat_response``. On any
    exception it yields one extra response whose message has the System role
    and carries a user-facing error text. The ``DONE`` marker is always
    yielded last, after which the session's conversation (if any) is written
    back through ``ConversationDAO``.
    """
    # Start from the module-level empty response so the except branch has
    # something to copy even if chat_request_to_response itself raises.
    base_response = EMPTY_RESPONSE
    session = None
    # NOTE(review): start_time is assigned but never read in this function.
    start_time = time.time()
    try:
        base_response = chat_request_to_response(chat_request)

        logger.info("Start processing chat {} for {}, using model {}".format(chat_request.chat_id, chat_request.user,
                                                                             chat_request.model))
        # init
        # 1. get or create conversation in DB (not add current chat message yet)
        # 2. init chat session
        conversation = await get_or_create_conversation(chat_request)

        session = await CodeInterpreterStreamSession.create(model_name=chat_request.model, conversation=conversation)

        user_messages = [Message(RoleType.User, message["content"]) for message in chat_request.messages]
        input_files = _get_input_file(chat_request)

        # yield chat response piece by piece
        async for chat_response in process_chat_response(session, base_response, user_messages, input_files):
            yield chat_response
    # `Exception` is part of the tuple, so every ordinary exception lands here.
    except (ModelMaxIterationsException, DependencyException, InputErrorException, InternalErrorException, Exception) \
            as e:
        # Looked up by exact type below: subclasses of these classes do not
        # match and fall through to the "Unknown error occurred" text.
        exception_messages = {
            ModelMaxIterationsException: "Sorry. The agent didn't find the correct answer after multiple trials, "
                                         "Please try another question.",
            DependencyException: "Agent failed to process message due to dependency issue. You can try it later. "
                                 "If it still happens, please contact oncall.",
            InputErrorException: "Agent failed to process message due to value issue. If you believe all input are "
                                 "correct, please contact oncall.",
            InternalErrorException: "Agent failed to process message due to internal error, please contact oncall.",
            Exception: "Agent failed to process message due to unknown error, please contact oncall."
        }

        err_msg = exception_messages.get(type(e), f"Unknown error occurred: {str(e)}")
        logger.error(err_msg, exc_info=True)
        message = Message(role=RoleType.System, content=err_msg)
        # The error is reported to the client with FINISH_STATUS; only the
        # stored conversation is marked FAILED.
        yield update_chat_response_with_message(base_response.copy(), message, status=FINISH_STATUS).dict()
        if session and session.conversation:
            session.conversation.status = ConversationStatus.FAILED

    # End-of-stream marker, sent on both the success and the failure path.
    yield DONE

    # Persist the conversation last; this runs only if the consumer keeps
    # iterating past DONE.
    if session and session.conversation:
        await ConversationDAO.update_conversation(session.conversation)
async def get_or_create_conversation(chat_request: ChatCompleteRequest):
    """
    Get conversation data from db or create new conversation based on input and store updated in DB.

    A missing conversation is created from the request and added; an existing
    one is updated from the request and written back.

    :return: the conversation object that was stored.
    :raises InputErrorException: if the request has an empty ``chat_id``.
    """
    chat_id = chat_request.chat_id
    if not chat_id:
        raise InputErrorException("Invalid chat ID.")

    # Check if conversation exists in db
    conversation_data = await ConversationDAO.get_conversation(chat_id)

    if not conversation_data:
        logger.info("No existing conversation for {}. Creating new conversation.".format(chat_id))
        conversation_data = ConversationDO.create_conversation_from_request(chat_request)
        await ConversationDAO.add_conversation(conversation_data)
    else:
        # TODO: Add status management, change status after fail then user can re-run, then reject request while
        #  another session is updating the conversation
        if conversation_data.is_in_running_status():
            # Only warns for now; the request is still processed. The chat id
            # is now filled into the placeholder (it was logged as a literal "{}").
            logger.warning("Conversation {} is still running, should aborting new changes".format(chat_id))

        logger.info("Got existing conversation {}, starting session.".format(chat_id))
        # update current info into existing conversation
        conversation_data = conversation_data.update_from_chat_request(chat_request)
        await ConversationDAO.update_conversation(conversation_data)

    return conversation_data
def update_chat_status_local(async_chat_generator):
    """Yield ``(item, is_last)`` pairs for each item of a (synchronous) iterable.

    Items are emitted one step behind the input so the final item can be
    flagged with ``is_last=True``. An empty iterable yields nothing.

    Despite the parameter name, the argument is consumed with a plain ``for``
    loop and must be a regular iterable.
    """
    # A private sentinel marks "nothing buffered yet" so that items which are
    # themselves None are passed through instead of being dropped.
    _unset = object()
    buffered_chat = _unset
    for item in async_chat_generator:
        if buffered_chat is not _unset:
            yield buffered_chat, False
        buffered_chat = item
    if buffered_chat is not _unset:
        yield buffered_chat, True
Message(role=RoleType.System, content="Failed: {}".format(str(e))) + yield update_chat_response_with_message(base_response.copy(), failed_message, status=FAILED_STATUS).dict() + + yield DONE + + +async def chat_local_event(chat_request: ChatCompleteRequest): + """ + Init a chat session and start pushing response back + """ + message = Message(role=RoleType.User, content="Hi") + + base_response = chat_request_to_response(chat_request) + return update_chat_response_with_message(base_response.copy(), message).dict() diff --git a/src/infiagent/tools/__init__.py b/src/infiagent/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a47dc216034d132e6a3d45edbaca3e403df7660b --- /dev/null +++ b/src/infiagent/tools/__init__.py @@ -0,0 +1,8 @@ +from .base_tool import BaseTool +from .code_sandbox import PythonSandBoxToolResponse, AsyncPythonSandBoxTool +try: + import docker +except: + pass +else: + from .code_tool_docker import CodeTool, PythonSandBoxToolResponseDocker \ No newline at end of file diff --git a/src/infiagent/tools/base_tool.py b/src/infiagent/tools/base_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..7097285b393928c301a1423b0a95beaeaa401ff7 --- /dev/null +++ b/src/infiagent/tools/base_tool.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass +from typing import Optional, Type +from abc import ABC +from importlib import import_module + +from ..exceptions.exceptions import InvalidConfigException +from ..utils import Config + + +@dataclass +class BaseToolRequest(ABC): + input_text: Optional[str] + + +@dataclass +class BaseToolResponse(ABC): + output_text: Optional[str] + + +# BaseTool +class BaseTool(ABC): + _name = None + _description = None + + def __init__(self, name, description, **kwargs): + self._name = name + self._description = description + self.setup() + + @property + def name(self): + """Getter for name.""" + return self._name + + @property + def description(self): + """Getter for 
description.""" + return self._description + + @classmethod + def from_config(cls, config_input, **kwargs): + """Create a BaseTool instance from a config file path or a config data dictionary. + + :param config_input: Either a file path to a config file or a config data dictionary. + :type config_input: str or dict + :param kwargs: Additional keyword arguments to pass to the class constructor. + :return: A BaseTool instance. + :rtype: BaseTool + """ + if isinstance(config_input, str): + # If config_input is a string, assume it's a file path. + config_data = Config.load(config_input) + elif isinstance(config_input, dict): + # If config_input is a dict, use it directly as config_data. + config_data = config_input + else: + raise InvalidConfigException( + f"Invalid config_input type: {type(config_input)}. " + "Expected str (file path) or dict (config data)." + ) + + module_name = config_data['module_name'] + class_name = config_data['class_name'] + module = import_module(module_name) + clazz = getattr(module, class_name) + return clazz(**config_data, **kwargs) + + @classmethod + async def async_from_config(cls, config_input, **params): + """Asynchronously create a BaseTool instance from a config file path or a config data dictionary. + + :param config_input: Either a file path to a config file or a config data dictionary. + :type config_input: str or dict + :param params: Additional parameters to pass to the create method. + :return: A BaseTool instance. + :rtype: BaseTool + """ + + + if isinstance(config_input, str): + # If config_input is a string, assume it's a file path. + config_data = Config.load(config_input) + elif isinstance(config_input, dict): + # If config_input is a dict, use it directly as config_data. + config_data = config_input + else: + raise InvalidConfigException( + f"Invalid config_input type: {type(config_input)}. " + "Expected str (file path) or dict (config data)." 
+ ) + + + module_name = config_data['module_name'] + class_name = config_data['class_name'] + module = import_module(module_name) + clazz = getattr(module, class_name) + + return await clazz.create(config_data, **params) + + @classmethod + async def async_from_config_path(cls, config_path, **params): + return await cls.async_from_config_data(config_data=Config.load(config_path), **params) + + @classmethod + async def async_from_config_data(cls, config_data, **params): + module_name = config_data['module_name'] + class_name = config_data['class_name'] + + module = import_module(module_name) + clazz = getattr(module, class_name) + + return await clazz.create(config_data, **params) + + @classmethod + async def create(cls, config_data, **params): + """ + Async create tool instance. init cannot be async, so wrap async init logic here. + """ + pass + + def setup(self): + pass + + def run(self, req: BaseToolRequest): + pass + + async def async_run(self, req: BaseToolRequest): + """ + Async run tool. 
+ """ + return self.run(req) diff --git a/src/infiagent/tools/code_sandbox/__init__.py b/src/infiagent/tools/code_sandbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26363d6a2062339639ab6ace9e9f038483f3297a --- /dev/null +++ b/src/infiagent/tools/code_sandbox/__init__.py @@ -0,0 +1 @@ +from .python_code_sandbox import PythonSandBoxToolResponse, AsyncPythonSandBoxTool diff --git a/src/infiagent/tools/code_sandbox/python_code_sandbox.py b/src/infiagent/tools/code_sandbox/python_code_sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a344d6d7f85467cfddfd0aba11558e53a663f5 --- /dev/null +++ b/src/infiagent/tools/code_sandbox/python_code_sandbox.py @@ -0,0 +1,209 @@ +from typing import Union, Dict +from werkzeug.datastructures import FileStorage +from ...tools.base_tool import BaseTool +from ...utils import clean_ansi, get_logger +from jupyter_client import BlockingKernelClient +import json +import os +import queue +import re +import subprocess +import sys +import time +import traceback +from enum import Enum +from ...utils.file_utils import clear_files + +logger = get_logger() + +root_directory = os.path.abspath(__file__) +while 'infiagent' not in os.path.basename(root_directory): + root_directory = os.path.dirname(root_directory) + +WORK_DIR = f'{root_directory}/tmp/ci_workspace' +FILE_DIR = f'{root_directory}/tmp/upload_files' + + +class _Type(Enum): + SUCCESS = 1 + ERROR = 2 + FAIL = 3 + + +class PythonSandBoxToolResponse: + + def __init__(self, + sand_box_response: str, + _type: _Type) -> None: + self._sand_box_response = sand_box_response + self._type = _type + + @property + def output_text(self): + return self._format(self._sand_box_response, self._type) + + @property + def raw_output(self): + return self._sand_box_response + + @classmethod + def _format(cls, sandbox_response, _type): + if _type == _Type.FAIL: + msg = f"\nCode execution error\n" + msg += f"What happened: {sandbox_response}" + 
else: + msg = "" + if _type == _Type.SUCCESS: + msg += "\nSTDOUT:\n" + msg += f"```python\n{clean_ansi(sandbox_response)}\n```" + "\n" + elif _type == _Type.ERROR: + msg += "\nSTDERR:\n" + msg += f"```python\n{clean_ansi(sandbox_response)}\n```" + "\n" + return msg + + +class AsyncPythonSandBoxTool(BaseTool): + _KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {} + LAUNCH_KERNEL_PY = (f"import os\nos.chdir('{root_directory}/tmp')\nfrom ipykernel import kernelapp as " + f"app\napp.launch_new_instance()") + + def __init__(self, name, description, **kwargs): + super().__init__(name, description, **kwargs) + self._sandbox_id = None + + @classmethod + async def create(cls, config_data, **params): + # Unpack the config_data dictionary and any additional parameters + instance = cls(name=config_data['name'], description=config_data['description'], **params) + return instance + + @classmethod + def kill_kernels(cls, sandbox_id): + if sandbox_id in AsyncPythonSandBoxTool._KERNEL_CLIENTS: + AsyncPythonSandBoxTool._KERNEL_CLIENTS[sandbox_id].shutdown() + del AsyncPythonSandBoxTool._KERNEL_CLIENTS[sandbox_id] + clear_files(os.path.join(WORK_DIR, sandbox_id)) + clear_files(os.path.join(FILE_DIR, sandbox_id)) + + def _start_kernel(self) -> BlockingKernelClient: + connection_file = os.path.join(WORK_DIR, self.sandbox_id, f'kernel_connection_file_{self.sandbox_id}.json') + launch_kernel_script = os.path.join(WORK_DIR, self.sandbox_id, f'launch_kernel_{self.sandbox_id}.py') + for f in [connection_file, launch_kernel_script]: + if os.path.exists(f): + os.remove(f) + + os.makedirs(os.path.join(WORK_DIR, self.sandbox_id), exist_ok=True) + with open(launch_kernel_script, 'w') as fout: + fout.write(AsyncPythonSandBoxTool.LAUNCH_KERNEL_PY) + + kernel_process = subprocess.Popen([ + sys.executable, + launch_kernel_script, + '--IPKernelApp.connection_file', + connection_file, + '--matplotlib=inline', + '--quiet', + ], + cwd=WORK_DIR) + + # Wait for kernel connection file to be written + 
while True: + if not os.path.isfile(connection_file): + time.sleep(0.1) + else: + # Keep looping if JSON parsing fails, file may be partially written + try: + with open(connection_file, 'r') as fp: + json.load(fp) + break + except json.JSONDecodeError: + pass + + # Client + kc = BlockingKernelClient(connection_file=connection_file) + kc.load_connection_file() + kc.start_channels() + kc.wait_for_ready() + return kc + + async def set_sandbox_id(self, sandbox_id): + self._sandbox_id = sandbox_id + + @property + def sandbox_id(self): + """Getter for sandbox_id.""" + return self._sandbox_id + + async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]) -> str: + return os.path.join(root_directory, f"tmp/upload_files/{self.sandbox_id}/{file.split('/')[-1]}") + + @staticmethod + def _input_handler(input_code: str) -> str: + # 使用正则表达式查找被三重反引号包围的代码块 + code_blocks = re.findall(r'```(?:python)?\s*(.*?)\s*```', input_code, re.DOTALL) + + # 合并所有找到的代码块 + python_code_cleaned = '\n'.join(code_blocks).strip() + + return python_code_cleaned + + @staticmethod + def _escape_ansi(line: str) -> str: + ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', line) + + @staticmethod + def _execute_code(kc: BlockingKernelClient, code: str) -> PythonSandBoxToolResponse: + kc.wait_for_ready() + kc.execute(code) + result = [] + state = _Type.FAIL + + while True: + finished = False + try: + msg = kc.get_iopub_msg() + msg_type = msg['msg_type'] + logger.info(msg_type) + if msg_type == 'status': + if msg['content'].get('execution_state') == 'idle': + finished = True + elif msg_type == 'execute_result': + text = msg['content']['data'].get('text/plain', '') + result.append(text) + state = _Type.SUCCESS + elif msg_type == 'stream': + text = msg['content']['text'] + result.append(text) + state = _Type.SUCCESS + elif msg_type == 'error': + text = AsyncPythonSandBoxTool._escape_ansi('\n'.join(msg['content']['traceback'])) + result.append(text) + 
state = _Type.ERROR + except queue.Empty: + text = 'Timeout: Code execution exceeded the time limit.' + result.append(text) + state = _Type.FAIL + finished = True + except Exception: + text = 'The code interpreter encountered an unexpected error.' + result.append(text) + logger.error(''.join(traceback.format_exception(*sys.exc_info()))) + state = _Type.FAIL + finished = True + if finished: + break + output = '\n'.join(result) + return PythonSandBoxToolResponse(sand_box_response=output, _type=state) + + async def async_run(self, req: str): + formatted_input = self._input_handler(req) + if self.sandbox_id in AsyncPythonSandBoxTool._KERNEL_CLIENTS: + kc = AsyncPythonSandBoxTool._KERNEL_CLIENTS[self.sandbox_id] + else: + kc = self._start_kernel() + AsyncPythonSandBoxTool._KERNEL_CLIENTS[self.sandbox_id] = kc + + return self._execute_code(kc, formatted_input) + + diff --git a/src/infiagent/tools/code_tool_docker.py b/src/infiagent/tools/code_tool_docker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0281375105c94f130c58a5f612a2be11bc107df --- /dev/null +++ b/src/infiagent/tools/code_tool_docker.py @@ -0,0 +1,210 @@ +import os +import pathlib +from typing import Tuple, Optional, IO, Union, Dict +import time +from hashlib import md5 +import docker +from ..tools.base_tool import BaseTool, BaseToolRequest, BaseToolResponse +import re +from ..exceptions.exceptions import InputErrorException, SandBoxFileUploadException +from werkzeug.datastructures import FileStorage +from ..utils import get_logger + +logger = get_logger() + +try: + import docker +except ImportError: + docker = None + +WORKING_DIR = os.path.join(os.getcwd(), "tmp/code_space") +OUTPUT_DIR = os.path.join(os.getcwd(), "tmp/output_space") +UPLOAD_PATH = os.path.join(os.getcwd(), "tmp/upload_files") + + +class CodeToolRequest(BaseToolRequest): + """ + Request for Code Tool + """ + def __init__(self, code_str: str): + # code_str = 'import pandas as pd\nimport numpy as np\n'+ code_str + 
code_blocks = re.findall(r'```(?:python)?\s*(.*?)\s*```', code_str, re.DOTALL) + python_code_cleaned = '\n'.join(code_blocks).strip() + self.code = python_code_cleaned + +class PythonSandBoxToolResponseDocker: + def __init__(self, formatter, raw_output) -> None: + self.formatter = formatter + self.raw_output = raw_output + + @property + def output_text(self): + return self.formatter.format(self.raw_output) + + +class CodeToolResponse(BaseToolResponse): + """ + Response for Code Tool + """ + def __init__(self, exit_code: int, log: str, output_dir: str): + self.exit_code = exit_code + self.log = log + self.output_dir = output_dir + self.output_text = log + + def to_dict(self): + return { + "exit_code": self.exit_code, + "log": self.log, + "output_dir": self.output_dir + } + + +class CodeTool(BaseTool): + """ + Code Tool for code execution + """ + def __init__(self, + name: Optional[str] = "Code Tool", + description: Optional[str] = "tool for code_exec", + # code_tool_id: Optional[str] = "code", + image: Optional[str] = "myimg", + time_out: Optional[int] = 60, + work_dir: Optional[str] = WORKING_DIR, + output_dir: Optional[str] = OUTPUT_DIR, + **kwargs + ): + super().__init__(name, description, **kwargs) + self._client = docker.from_env() + self._image = image + self._time_out = time_out + self._work_dir = work_dir + self._output_dir = output_dir + self._upload_file_name = None + self._upload_file_path = None + self._code_idx = md5(str(time.time()).encode()).digest().hex() + self._log_len = 0 + + @classmethod + async def create(cls, config_data, **params): + # Unpack the config_data dictionary and any additional parameters + instance = cls(name=config_data['name'], description=config_data['description'], **params) + return instance + + async def set_sandbox_id(self, sandbox_id): + self._sandbox_id = sandbox_id + + @property + def sandbox_id(self): + """Getter for sandbox_id.""" + return self._sandbox_id if self._sandbox_id else None + + + async def async_run(self, 
req: str): + req = CodeToolRequest(req) + code = req.code + if code is None: + return "No code to execute", 1, "" + + # path and file name for python script + abs_path = pathlib.Path(self._work_dir).absolute() + code_hash = self._code_idx + file_name = f"exec_code_{code_hash}.py" + file_path = os.path.join(self._work_dir, file_name) + self._file_path = file_path + file_dir = os.path.dirname(file_path) + self._file_dir = file_dir + os.makedirs(file_dir, exist_ok=True) + os.makedirs(OUTPUT_DIR, exist_ok=True) + if self._upload_file_name: + upload_file_path = os.path.join(UPLOAD_PATH, self._upload_file_name) + + # write code to file + with open(file_path, "a", encoding="utf-8") as fout: + fout.write(code) + cmd = f'python3 {file_name}' + + # create docker container + start_time = time.time() + if self._upload_file_name: + container = self._client.containers.run( + image=self._image, + command=cmd, + detach=True, + working_dir="/workspace", + mem_limit='1024m', + volumes={abs_path: {'bind': '/workspace','mode': 'rw'}, + upload_file_path: {'bind': f'/tmp/upload_files/{self._upload_file_name}','mode': 'rw'}}, + ) + else: + container = self._client.containers.run( + image=self._image, + command=cmd, + detach=True, + working_dir="/workspace", + mem_limit='10m', + volumes={abs_path: {'bind': '/workspace','mode': 'rw'}}, + ) + + # hold for time_out seconds + while container.status != "exited" and time.time() - start_time < self._time_out: + container.reload() + + # if time out, stop and remove container + if container.status != "exited": + container.stop() + container.remove() + return "TIMEOUT", 1, "" + + # save log to file + logs = container.logs().decode("utf-8").rstrip() + with open(os.path.join(file_dir, f'log.txt'), 'w') as log_file: + log_file.write(logs) + new_len = len(logs) + logs = logs[self._log_len:] + self._log_len = new_len + + exit_code = container.attrs["State"]["ExitCode"] + container.remove() + + # save files to output space and rmv files in working space 
+ output_dir = os.path.join(OUTPUT_DIR, f'output_{code_hash}') + self._output_dir = output_dir + # os.makedirs(output_dir, exist_ok=True) + # os.rename(file_path, os.path.join(file_dir, 'exec_code.py')) + # for f in os.listdir(abs_path): + # os.rename(os.path.join(abs_path, f), os.path.join(output_dir, f)) + # os.rmdir(abs_path) + + response = CodeToolResponse(exit_code, logs, output_dir) + + return response + + async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]) -> str: + if isinstance(file, str): + logger.info(f"Upload File As FilePath: {file}") + file_path = await self.upload_file(file) + else: + err_msg = f"Invalid file input type. Expected str, FileStorage, or Dict. Got {type(file)}" + logger.error(err_msg) + raise InputErrorException(err_msg) + + return file_path + + async def upload_file(self, file_path: str): + file_name = file_path.split("/")[-1] # Extract the file name from the path + self._upload_file_path = file_path + self._upload_file_name = file_name + + return file_path + + async def save_file(self): + output_dir = self._output_dir + file_path = self._file_path + file_dir = self._file_dir + abs_path = pathlib.Path(self._work_dir).absolute() + os.makedirs(output_dir, exist_ok=True) + os.rename(file_path, os.path.join(file_dir, 'exec_code.py')) + for f in os.listdir(abs_path): + os.rename(os.path.join(abs_path, f), os.path.join(output_dir, f)) + os.rmdir(abs_path) diff --git a/src/infiagent/utils/__init__.py b/src/infiagent/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0c00c54152783e20930eac0100b49ae4e0c10f --- /dev/null +++ b/src/infiagent/utils/__init__.py @@ -0,0 +1,8 @@ +from .logger import * +from .config import * +from .loader import * +from .file_utils import * +from .session_utils import * +from .string_utils import * +from .common_utils import * +from .system_messages import * diff --git a/src/infiagent/utils/common_utils.py b/src/infiagent/utils/common_utils.py new file mode 100644 
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/infiagent/utils/config.py b/src/infiagent/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a27673d1606454f69652825697e94bf2c17b403b --- /dev/null +++ b/src/infiagent/utils/config.py @@ -0,0 +1,72 @@ +import yaml +from typing import Dict, AnyStr, Union, Any +from pathlib import Path + +from ..prompt import SimpleReactPrompt, ZeroShotReactPrompt +from .logger import get_logger + +logger = get_logger() + + +class Config: + """ + A class for loading and creating configuration dictionaries from files or dictionaries. + """ + + @staticmethod + def _prompt_constructor(loader, node): + value = node.value + if value == "SimpleReactPrompt": + return SimpleReactPrompt() + elif value == "ZeroShotReactPrompt": + return ZeroShotReactPrompt() + else: + logger.warning(f"Unknown prompt name: {value}. use default SimpleReactPrompt") + return SimpleReactPrompt() + + @staticmethod + def load(path: Union[Path, AnyStr]) -> Dict[AnyStr, Any]: + """ + Load a configuration dictionary from a YAML file. + + :param path: The path to the configuration file. + :type path: Union[Path, AnyStr] + :raises FileNotFoundError: If the file is not found. + :raises yaml.YAMLError: If a YAML error occurred while loading the file. + :raises Exception: If an unexpected error occurred. + :return: A dictionary containing the configuration. 
+ :rtype: Dict[AnyStr, Any] + """ + # logger the start of the loading process + logger.info(f"Starting to load configuration from {path}") + + # Register the custom prompt constructor with PyYAML + yaml.add_constructor('!prompt', Config._prompt_constructor) + + try: + with open(path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + logger.info(f"Successfully loaded configuration from {path}") + return config + except FileNotFoundError: + logger.error(f"Config file {path} not found") + raise FileNotFoundError(f"Config file {path} not found") + except yaml.YAMLError as e: + logger.error(f"YAML error occurred while loading the configuration: {str(e)}", exc_info=True) + raise yaml.YAMLError(e) + except Exception as e: + logger.error(f"An unexpected error occurred: {str(e)}", exc_info=True) + raise Exception(e) + + @staticmethod + def from_dict(config: Dict[AnyStr, Any]) -> Dict[AnyStr, Any]: + """ + Create a configuration dictionary from a Python dictionary. + + :param config: A dictionary containing configuration parameters. + :type config: Dict[AnyStr, Any] + :return: A dictionary containing the configuration. 
+ :rtype: Dict[AnyStr, Any] + """ + logger.info(f"Creating Config from dictionary") + return config diff --git a/src/infiagent/utils/file_utils.py b/src/infiagent/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5cda64a78bb61c675878fbdc7edd1c52ba479f77 --- /dev/null +++ b/src/infiagent/utils/file_utils.py @@ -0,0 +1,156 @@ +import csv +import io +import logging +import os +import shutil + +import chardet + +from .logger import get_logger + +root_directory = os.path.abspath(__file__) +while 'infiagent' not in os.path.basename(root_directory): + root_directory = os.path.dirname(root_directory) + +TEMP_FILE_UPLOAD_DIR = f"{root_directory}/tmp/upload_files/" +MAX_INPUT_FILE_SIZE = 1024 * 1024 * 1024 +SAMPLE_FILE_SIZE = 2048 +CSV_DEFAULT_DELIMITER = "," +CSV_DELIMITERS = [',', '\t', ';', '|', ' '] + +logger = get_logger() + + +def clear_files(upload_file_dir): + for filename in os.listdir(upload_file_dir): + file_path = os.path.join(upload_file_dir, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Failed to delete %s. Error: %s' % (file_path, e)) + shutil.rmtree(upload_file_dir) + + +def upload_files(uploaded_files, sandbox_id): + uploaded_files_list = [] + + if not uploaded_files: + logging.info("No file upload") + return uploaded_files_list + else: + logging.info("Got {} files to upload.".format(len(uploaded_files))) + + FILE_DIR = os.path.join(TEMP_FILE_UPLOAD_DIR, sandbox_id) + if os.path.exists(FILE_DIR): + clear_files(FILE_DIR) + else: + # if the demo_folder directory is not present then create it. 
+ os.makedirs(FILE_DIR) + + for uploaded_file in uploaded_files: + # 获取文件的基本信息 + file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type, "FileSize": uploaded_file.size} + logging.info(file_details) + + uploaded_files_list.append(_process_files(uploaded_file, FILE_DIR)) + + logging.info("All files saved to disk.") + + return uploaded_files_list + + +def _process_files(uploaded_file, output_dir): + # Check if file size is more than 1 GB + if uploaded_file.size > MAX_INPUT_FILE_SIZE: + raise ValueError(f"File {uploaded_file.name} is larger than 1 GB") + + # Check if the file is a CSV and if the delimiter meets requirement + if uploaded_file.name.endswith('.csv'): + return _process_local_csv_file(uploaded_file, output_dir) + else: + new_file_path = os.path.join(output_dir, uploaded_file.name) + with open(new_file_path, 'wb') as new_file: + new_file.write(uploaded_file.getvalue()) + return new_file_path + + +def _process_local_csv_file(uploaded_file, output_dir): + """ + Process the uploaded file to convert the delimiter if needed and save the content in the output directory. + + Args: + - uploaded_file: File-like object of the uploaded file + - output_dir (str): Directory where the processed file should be saved + + Returns: + - str: The path to the saved file + """ + # Decode the content of the uploaded file + file_content = uploaded_file.read() + content_stream = io.BytesIO(file_content) + + # Process the content stream + converted_file_stream, converted = convert_delimiter_to_comma(content_stream) + + # Construct the output path + new_file_path = os.path.join(output_dir, uploaded_file.name) + + # Write the processed content to the output path + with open(new_file_path, 'wb') as file: + file.write(converted_file_stream.getvalue()) + + return new_file_path + + +def convert_delimiter_to_comma(content_stream: io.BytesIO) -> (io.BytesIO, bool): + """ + Detects the delimiter of a CSV content stream and converts it to comma if it's not already. 
+ + Args: + - content_stream (io.BytesIO): Stream containing CSV content + + Returns: + - tuple: New content stream with updated delimiter, flag indicating if conversion was done + """ + sample = content_stream.read(SAMPLE_FILE_SIZE) + content_stream.seek(0) + + # Use chardet to detect the encoding + detected = chardet.detect(sample) + encoding = detected.get('encoding', 'utf-8') or 'utf-8' + decoded_sample = sample.decode(encoding, errors='replace') + + sniffer = csv.Sniffer() + try: + delimiter = sniffer.sniff(decoded_sample, delimiters=''.join(CSV_DELIMITERS)).delimiter + except (csv.Error, UnicodeDecodeError) as e: + logger.warning("Unable to confidently determine the delimiter for the CSV content. Return original file. " + "error: {}".format(str(e))) + return content_stream, False + + if delimiter == CSV_DEFAULT_DELIMITER: + logger.info("Original CSV file delimiter is ',', no need to convert.") + return content_stream, False + + logger.info("Original CSV file delimiter is '{}', converting it to ','.".format(delimiter)) + reader = csv.reader(content_stream.getvalue().decode('utf-8').splitlines(), delimiter=delimiter) + temp_output = io.StringIO() # Temporary StringIO to hold string representation + writer = csv.writer(temp_output, delimiter=CSV_DEFAULT_DELIMITER, lineterminator='\n') + + for row in reader: + writer.writerow(row) + + # Convert StringIO value to bytes and write to BytesIO stream + output_stream = io.BytesIO() + output_stream.write(temp_output.getvalue().encode('utf-8')) + output_stream.seek(0) + return output_stream, True + + +def get_file_name_and_path(input_file: str): + file_name = input_file.split("/")[-1] + tos_path = input_file.replace(file_name, "") + return file_name, tos_path \ No newline at end of file diff --git a/src/infiagent/utils/loader.py b/src/infiagent/utils/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ed44c4e3c0af5fa479da63ffcfce0fc054319b9c --- /dev/null +++ b/src/infiagent/utils/loader.py @@ 
-0,0 +1,115 @@ +import importlib +import os +from pathlib import Path +from typing import IO, Any + +import yaml + + +class Loader(yaml.SafeLoader): + """ + A custom YAML loader that adds support for various custom tags: + + - !include: includes a YAML file as a subdocument + - !prompt: returns a prompt class based on the specified string + - !tool: returns a tool class based on the specified string + - !env: returns the value of an environment variable + - !file: returns the contents of a file + """ + def __init__(self, stream: IO[Any]) -> None: + """ + Initializes a new instance of the Loader class. + :param stream: The stream to load YAML from. + :type stream: IOBase + """ + self._root = Path(stream.name).resolve().parent + super(Loader, self).__init__(stream) + self.add_constructor("!include", Loader.include) + self.add_constructor("!prompt", Loader.prompt) + self.add_constructor("!tool", Loader.tool) + self.add_constructor("!env", Loader.env) + self.add_constructor("!file", Loader.file) + + def include(self, node: yaml.Node) -> Any: + """ + Loads a YAML file from a path relative to the current file. Use this tag to include other agent configs as plugins. + + :param node: The YAML node to be loaded. + :type node: yaml.Node + :return: The loaded YAML file. + :rtype: Any + """ + filename = Path(self.construct_scalar(node)) + if not filename.is_absolute(): + filename = self._root / filename + with open(filename, 'r') as f: + return yaml.load(f, Loader) + + def prompt(self, node: yaml.Node) -> Any: + """ + Returns a PromptTemplate class based on the specified string. + + :param node: The YAML node representing the prompt string. + :type node: yaml.Node + :return: The prompt class. + :rtype: type + :raises AssertionError: If the resolved prompt class is not a subclass of PromptTemplate. + """ + from ..prompt import PromptTemplate, SimpleReactPrompt, ZeroShotReactPrompt + prompt = self.construct_scalar(node) + if '.' 
in prompt: + _path = prompt.split('.') + module = importlib.import_module('.'.join(_path[:-1])) + prompt_cls = getattr(module, _path[-1]) + else: + prompt_cls = eval(prompt) + assert issubclass(prompt_cls.__class__, PromptTemplate) + return prompt_cls + + def tool(self, node: yaml.Node) -> Any: + """ + Loads a Custom BaseTool class from a path relative to the current file. + + :param node: The YAML node to be loaded. + :type node: yaml.Node + :return: The loaded BaseTool class. + :rtype: Any + """ + from ..tools import BaseTool, PythonSandBoxTool + + tool = self.construct_scalar(node) + if '.' in tool: + _path = tool.split('.') + module = importlib.import_module('.'.join(_path[:-1])) + tool_cls = getattr(module, _path[-1]) + else: + tool_cls = eval(tool) + + assert issubclass(tool_cls, BaseTool) + return tool_cls + + def env(self, node: yaml.Node) -> Any: + """ + Loads an environment variable from the current environment, defaults to an empty string if the variable is not set. + + :param node: The YAML node to be loaded. + :type node: yaml.Node + :return: The loaded environment variable. + :rtype: Any + """ + return os.environ.get(self.construct_scalar(node), "") + + def file(self, node: yaml.Node) -> Any: + """ + Loads any readable file from a path relative to the current file. + + :param node: The YAML node to be loaded. + :type node: yaml.Node + :return: The loaded file. 
+ :rtype: Any + """ + filename = Path(self.construct_scalar(node)) + if not filename.is_absolute(): + filename = self._root / filename + with open(filename, 'r') as f: + return f.read().strip() \ No newline at end of file diff --git a/src/infiagent/utils/logger.py b/src/infiagent/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..400eefd60ee230b13de1bafcbd67ac9e80322a7d --- /dev/null +++ b/src/infiagent/utils/logger.py @@ -0,0 +1,38 @@ + +# Custom Logger Adapter to include X-Tt-Logid +import logging +from contextvars import ContextVar + + +log_id_var: ContextVar[str] = ContextVar("log_id", default="") + + +class ContextualLoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + log_id = log_id_var.get() + return f"[{log_id}] : {msg}", kwargs + + +def init_logging(): + """ + Initialize logging configuration. + """ + # Basic logging configuration with your specified settings + # config_default() + openai_logger = logging.getLogger("openai") + openai_logger.setLevel(logging.WARNING) + + logging.basicConfig( + level=logging.INFO, + datefmt=r'%Y/%m/%d %H:%M:%S', + format=r'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s', + ) + + +def get_logger() -> logging.LoggerAdapter: + """ + Retrieve a logger instance configured with X-Tt-Logid. 
+ """ + logger = logging.getLogger("infiagent_logger") + logger.setLevel(logging.INFO) + return ContextualLoggerAdapter(logger, {}) diff --git a/src/infiagent/utils/session_utils.py b/src/infiagent/utils/session_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2763756ca696bc74648cf8c737a4304446c8049e --- /dev/null +++ b/src/infiagent/utils/session_utils.py @@ -0,0 +1,39 @@ +from .logger import get_logger + +logger = get_logger() + +MODEL_NAME_TO_CONFIG = { + "OPEN_AI": "../configs/agent_configs/react_agent_gpt4_async.yaml", + "AZURE_OPEN_AI": "../configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml", + "AZURE_GPT35_TURBO": "../configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml", + "AZURE_GPT4": "../configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml", + "LLAMA": "../configs/agent_configs/react_agent_llama_async.yaml", + "OPT": "../configs/agent_configs/react_agent_opt_async.yaml", + +} + + +def get_model_config_path(input_model_name): + if input_model_name is None: + model_name = "openai" + else: + model_name = input_model_name + + # check if same model name + if model_name in MODEL_NAME_TO_CONFIG: + return MODEL_NAME_TO_CONFIG[model_name] + + # check if converted to capital letters + if model_name.upper() in MODEL_NAME_TO_CONFIG: + return MODEL_NAME_TO_CONFIG[model_name.upper()] + + if "openai" in model_name: + return MODEL_NAME_TO_CONFIG["AZURE_OPEN_AI"] + + elif "llama" in model_name: + return MODEL_NAME_TO_CONFIG["LLAMA"] + elif "opt" in model_name: + return MODEL_NAME_TO_CONFIG["OPT"] + else: + logger.warning("unknown model name, use official.") + return MODEL_NAME_TO_CONFIG["AZURE_OPEN_AI"] diff --git a/src/infiagent/utils/string_utils.py b/src/infiagent/utils/string_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51379b60b0e74638d7da24d9d3413c0fa227d5c8 --- /dev/null +++ b/src/infiagent/utils/string_utils.py @@ -0,0 +1,74 @@ +import os +import random +import re 
import os
import random
import re
import string
from urllib.parse import urlparse

# Lower-case file extensions that is_image_link treats as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')

# ANSI CSI escape sequences (colour codes and similar terminal control).
_ANSI_ESCAPE_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')

# Single shared URL pattern so extraction and removal can never disagree.
_URL_RE = re.compile(
    r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
)


def is_image_link(url):
    """Return True when *url* points at a file with an image extension.

    The check is case-insensitive and also looks at the URL path alone, so
    links such as ``.../plot.PNG`` or ``.../plot.png?token=abc`` are
    recognised. Any URL that literally ends with an image extension is still
    accepted, as before.
    """
    if url.lower().endswith(_IMAGE_EXTENSIONS):
        return True
    try:
        path = urlparse(url).path
    except ValueError:
        # Malformed URL (e.g. a broken IPv6 host): not an image link.
        return False
    return path.lower().endswith(_IMAGE_EXTENSIONS)


def extract_filename_from_url(url):
    """Return the last path component of *url*, ignoring query and fragment."""
    parsed_url = urlparse(url)
    return os.path.basename(parsed_url.path)


def clean_ansi(text):
    """Return *text* with ANSI escape sequences removed."""
    return _ANSI_ESCAPE_RE.sub('', text)


def generate_random_string(length=12):
    """Return a random string of ASCII letters and digits of the given length.

    NOTE(review): this uses the ``random`` module, which is not suitable for
    security-sensitive tokens -- confirm no caller relies on unpredictability.
    """
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))


def extract_urls(text):
    """Return every http(s) URL found in *text*, in order of appearance."""
    return _URL_RE.findall(text)


def replace_latex_format(s):
    r"""Rewrite ``\(...\)`` and ``\[...\]`` LaTeX delimiters as ``$$...$$``."""
    # DOTALL lets a formula span several lines.
    s = re.sub(r'\\\((.*?)\\\)', r'$$\1$$', s, flags=re.DOTALL)
    s = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', s, flags=re.DOTALL)
    return s


def extract_and_replace_url(text):
    """Format sandbox output as a Markdown code block followed by links.

    URLs are removed from the text together with the "Generated an image: "
    and "Generated files on server: " labels, ANSI escapes are stripped, and
    the remaining text is wrapped in a ``python`` code fence. Each extracted
    URL is appended after the fence, as an image for image links or as a
    link labelled with the file name otherwise.
    """
    all_matched_urls = extract_urls(text)

    # Remove every match in one pass. Sequential str.replace calls corrupted
    # the text when one URL was a prefix of a later one (removing
    # "http://h/x" first left a stray ".png" from "http://h/x.png").
    text = _URL_RE.sub("", text)
    text = text.replace("Generated an image: ", "")
    text = text.replace("Generated files on server: ", "")

    new_urls = []
    for extracted_url in all_matched_urls:
        if is_image_link(extracted_url):
            new_urls.append(f'![The Image]({extracted_url})')
        else:
            filename = extract_filename_from_url(extracted_url)
            new_urls.append(f'[{filename}]({extracted_url})')

    text = clean_ansi(text)

    return f'```python\n{text}```' + "\n" + "\n".join(new_urls)


def contains_chinese(input_context: str) -> bool:
    """Return True if any character is in the U+4E00..U+9FFF range."""
    return any('\u4e00' <= char <= '\u9fff' for char in input_context)
# Prefixes that mark a message as coming from the system rather than the model.
SYSTEM_MESSAGE_PREFIX_EN = '[SYSTEM NOTIFICATION]'
SYSTEM_MESSAGE_PREFIX_CH = '【系统提示】'

# Notice shown when an uploaded CSV was converted to a comma delimiter (English).
FILE_CONVERTED_EN = (
    f"\n{SYSTEM_MESSAGE_PREFIX_EN}\n"
    "If you are seeing this message, it indicates that the CSV file you provided did not use a comma as its "
    "delimiter. We have automatically detected the delimiter of your file and converted it to a comma-separated "
    "format for processing. This means that the processed output may differ from your original input in terms of "
    "delimiters. To ensure delimiter consistency, please convert your file to a comma-separated format before "
    "uploading."
)

# Same notice, Chinese version.
FILE_CONVERTED_CH = (
    f"\n{SYSTEM_MESSAGE_PREFIX_CH}\n"
    "如果您看到此消息,表示您提供的CSV文件的分隔符并非逗号(',')。我们已自动检测到您文件的分隔符,并将其转换为逗号分隔的格式进行处理。"
    "这意味着处理后的输出分隔符可能与您原始的输入不同。为确保文件的一致性,请在上传之前将您的CSV文件转换为逗号分隔的格式。"
)

# Prefixes placed before tool input and tool output.
TOOL_INPUT_PREFIX_EN = f"{SYSTEM_MESSAGE_PREFIX_EN} We need to execute with python sandbox with the following code:"
OBSERVATION_PREFIX_EN = f"{SYSTEM_MESSAGE_PREFIX_EN} Running the above tool with the following response: "
TOOL_INPUT_PREFIX_CN = f"{SYSTEM_MESSAGE_PREFIX_CH} 执行如下代码: "
OBSERVATION_PREFIX_CN = f"{SYSTEM_MESSAGE_PREFIX_CH} 代码执行结果为: "

# Failure messages.
AGENT_FAILED_EN = f"{SYSTEM_MESSAGE_PREFIX_EN} Sorry Agent unable to answer this question due to LLM fail.\n"
AGENT_FAILED_CN = f"{SYSTEM_MESSAGE_PREFIX_CH} 对不起,模型暂时无法回答这个问题.\n"
AGENT_EXCEED_MAX_RETRY_EN = (
    f"{SYSTEM_MESSAGE_PREFIX_EN} Sorry agent unable to answer the questions within max "
    f"retry, please try another question."
)
AGENT_EXCEED_MAX_RETRY_CN = f"{SYSTEM_MESSAGE_PREFIX_CH} 对不起, 模型暂时无法在规定时间内回答这个问题,请换一个问题重试."
"""Integration test for code sandbox client"""

import pytest

import src.schemas.sandbox_models as dm
from src.tools.code_sandbox.async_sandbox_client import AsyncSandboxClient

# Placeholder session id and fixture files; replace with real values when
# running against a live sandbox.
SAMPLE_SESSION_ID = "test_session_id"
SAMPLE_FILE_PATH = "tests/integration_tests/sample_file.txt"
SAMPLE_FILE_NAME = "sample_file.txt"
SAMPLE_IMAGE_PATH = "tests/integration_tests/sample_image.jpg"
SAMPLE_IMAGE_NAME = "sample_image.jpg"

# NOTE(review): every call below is made synchronously. If the methods of
# AsyncSandboxClient are coroutines they must be awaited -- confirm against
# the client implementation.


def _finished_run_outputs(result):
    """Assert *result* is a completed run and return its output entries."""
    assert isinstance(result, dm.RunCodeOutput)
    assert result.data is not None
    assert result.data.is_partial is False
    assert result.data.result is not None
    return result.data.result.code_output_result


def test_run_code_success():
    """A trivial print produces exactly one stdout entry."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)

    result = client.run_code("print('Hello World!')", timeout=10)
    code_output = _finished_run_outputs(result)
    assert len(code_output) == 1
    # The field is ``type`` (see dm.CodeOutput(type=...)); ``__type`` never
    # existed and raised AttributeError.
    assert code_output[0].type == "stdout"
    assert code_output[0].content == "Hello World!\n"


def test_run_code_timeout():
    """An endless loop is reported as a partial result with a timeout message."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)

    result = client.run_code("while True:\n pass", timeout=1)
    assert isinstance(result, dm.RunCodeOutput)
    assert result.data.is_partial is True
    assert result.data.result is not None
    assert result.message == "the code doesn't finish in timeout value 1"


def test_run_code_error():
    """Code that raises produces exactly one stderr entry."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)

    result = client.run_code("print('hello', error='abc')", timeout=10)
    code_output = _finished_run_outputs(result)
    assert len(code_output) == 1
    assert code_output[0].type == "stderr"


def test_upload_download_text_file():
    """A text file survives an upload/download round trip unchanged."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    result = client.upload_file(SAMPLE_FILE_PATH, SAMPLE_FILE_NAME)
    assert isinstance(result, dm.UploadOutput)
    assert result.message == "succeed"
    assert result.data == f"/mnt/{SAMPLE_FILE_NAME}"

    result = client.download_file(SAMPLE_FILE_NAME)
    assert isinstance(result, dm.DownloadSuccessOutput)
    with open(SAMPLE_FILE_PATH, "r") as f:
        f_content = f.read()
    assert result.content == f_content


def test_upload_download_image_file():
    """A binary image survives an upload/download round trip unchanged."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    result = client.upload_file(SAMPLE_IMAGE_PATH, SAMPLE_IMAGE_NAME)
    assert isinstance(result, dm.UploadOutput)
    assert result.message == "succeed"
    assert result.data == f"/mnt/{SAMPLE_IMAGE_NAME}"

    result = client.download_file(SAMPLE_IMAGE_NAME)
    assert isinstance(result, dm.DownloadSuccessOutput)
    with open(SAMPLE_IMAGE_PATH, "rb") as f:
        f_content = f.read()
    assert result.content.encode() == f_content


def test_download_nonexist_file():
    """Downloading an unknown file yields an error response."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    result = client.download_file("this_file_should_not_exist")

    assert isinstance(result, dm.ErrorResponse)


def test_heartbeat():
    """Heartbeat returns either a heartbeat payload or an error response."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    result = client.heartbeat()

    assert isinstance(result, (dm.HeartbeatOutput, dm.ErrorResponse))


def test_refresh_sandbox():
    """Refresh returns either a refresh payload or an error response."""
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    result = client.refresh_sandbox()

    assert isinstance(result, (dm.RefreshSandboxOutput, dm.ErrorResponse))


# Session-scoped fixture: refresh the sandbox once after all tests have run.
@pytest.fixture(scope='session', autouse=True)
def refresh_sandbox_after_all_tests():
    yield
    client = AsyncSandboxClient(SAMPLE_SESSION_ID)
    client.refresh_sandbox()
import unittest
from abc import ABC
from unittest import mock

from src.agent import BaseAgent
from src.exceptions.exceptions import InputErrorException
from src.llm import BaseLLM
from src.prompt import PromptTemplate
from src.schemas import AgentType, AgentOutput
from src.tools import BaseTool


class SampleBaseAgent(BaseAgent):
    """BaseAgent subclass with a no-op ``run``, used as the agent under test."""

    def run(self, *args, **kwargs) -> AgentOutput:
        return None


class SampleBaseTool(BaseTool, ABC):
    """BaseTool subclass whose sync and async entry points do nothing."""

    def run(self, req):
        return None

    async def async_run(self, req):
        return None


class TestBaseAgent(unittest.TestCase):
    """Tests for the properties and plugin handling of BaseAgent."""

    def setUp(self):
        self.mock_llm = mock.create_autospec(BaseLLM)
        self.mock_prompt_template = mock.create_autospec(PromptTemplate)
        self.agent = SampleBaseAgent(
            name='TestAgent',
            type=AgentType.react,
            version='1.0',
            description='Test Description',
            prompt_template=self.mock_prompt_template,
        )
        self.tool = SampleBaseTool("test_tool", "test_tool")
        self.agent.add_plugin('test_tool', self.tool)

    def test_properties(self):
        expected_by_attribute = (
            ('name', 'TestAgent'),
            ('type', AgentType.react),
            ('version', '1.0'),
            ('description', 'Test Description'),
            ('prompt_template', self.mock_prompt_template),
        )
        for attribute, expected in expected_by_attribute:
            self.assertEqual(getattr(self.agent, attribute), expected)

    # llm setter
    def test_llm_setter_happy_path(self):
        self.agent.llm = self.mock_llm
        self.assertEqual(self.agent.llm, self.mock_llm)

    def test_llm_setter_input_error(self):
        with self.assertRaises(InputErrorException):
            self.agent.llm = 'invalid_llm'

    # add_plugin
    def test_add_plugin_happy_path(self):
        self.agent.add_plugin('test_tool', 'test_tool_instance')
        self.assertIn('test_tool', self.agent.plugins_map)

    def test_add_plugin_input_error(self):
        with self.assertRaises(InputErrorException):
            self.agent.add_plugin('', None)

    def test_get_prompt_template_dict(self):
        # Dictionary input: each value is passed through _parse_prompt_template.
        patcher = mock.patch.object(
            BaseAgent, '_parse_prompt_template', return_value=self.mock_prompt_template
        )
        with patcher:
            result = self.agent._get_prompt_template({'test_key': 'dict'})
            self.assertEqual(result, {'test_key': self.mock_prompt_template})

    def test_get_prompt_template_instance(self):
        # PromptTemplate input is returned as-is.
        prompt_instance = PromptTemplate(input_variables=["foo"], template="Say {foo}")
        self.assertEqual(self.agent._get_prompt_template(prompt_instance), prompt_instance)

    def test_clear(self):
        # Only checks that clear() can be called without raising.
        self.agent.clear()

    def test_get_plugin_tool_function(self):
        sync_functions = self.agent.get_plugin_tool_function()
        self.assertIn('test_tool', sync_functions)
        self.assertEqual(sync_functions['test_tool'], self.tool.run)

    def test_get_plugin_tool_async_function(self):
        async_functions = self.agent.get_plugin_tool_async_function()
        self.assertIn('test_tool', async_functions)
        self.assertEqual(async_functions['test_tool'], self.tool.async_run)


if __name__ == '__main__':
    unittest.main()
import unittest
from datetime import datetime

from src.db.conversation_do import ConversationDO, SandboxStatus, ConversationStatus, MessageConf
from src.schemas import RoleType


class TestConversationDO(unittest.TestCase):
    """Tests for construction, serialisation and update of ConversationDO."""

    def setUp(self):
        # Sample data for testing with the new fields
        self.sample_data = {
            "conversation_id": "12345",
            "messages": [{"role": "user", "content": "Hello"}],
            "input_files": [{"name": "file1.txt"}],
            "output_files": [{"name": "file2.txt"}],
            "sandbox_id": "sandbox_sample",
            "sandbox_status": "RUNNING",
            "create_time": datetime.utcnow(),
            "update_time": datetime.utcnow(),
            "user": "test_user_v3",
            "model_name": "sample_model",
            "model_conf_path": "path/to/conf",
            "llm_name": "sample_llm",
            "agent_type": "sample_agent",
            "request_id": "request_sample",
            "dead_sandbox_ids": ["dead1", "dead2"],
            "status": "RUNNING"
        }

    def _assert_message_conf(self, message_conf):
        """Assert *message_conf* carries the values used throughout these tests.

        The original assertions read ``message_conf.__temperature`` etc.
        Inside a class body Python mangles that to
        ``_TestConversationDO__temperature``, so they could only raise
        AttributeError.
        NOTE(review): the public attribute names are assumed from the
        ``MessageConf(temperature=..., top_p=..., top_k=...)`` keyword
        arguments -- confirm against MessageConf.
        """
        self.assertEqual(message_conf.temperature, 0.9)
        self.assertEqual(message_conf.top_p, 0.8)
        self.assertEqual(message_conf.top_k, 5)

    def test_init(self):
        conversation = ConversationDO(**self.sample_data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_to_dict(self):
        conversation = ConversationDO(**self.sample_data)
        data = conversation.to_dict()
        self.assertEqual(data["conversation_id"], "12345")
        self.assertEqual(data["sandbox_status"], "RUNNING")
        self.assertEqual(len(data["messages"]), 1)
        self.assertEqual(data["messages"][0]["role"], 0)
        self.assertEqual(data["messages"][0]["content"], "Hello")

    def test_from_dict(self):
        data = self.sample_data.copy()
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_update(self):
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "sandbox_status": "KILLED",
            "user": "updateduser",
            "messages": [{"role": "Agent", "content": "Hi"}]
        }
        conversation.update(updated_data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.KILLED)
        self.assertEqual(conversation.user, "updateduser")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.Agent)
        self.assertEqual(conversation.messages[0].content, "Hi")

    def test_invalid_status(self):
        # Testing invalid sandbox_status
        data = self.sample_data.copy()
        data["sandbox_status"] = "INVALID_STATUS"
        conversation = ConversationDO(**data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.UNKNOWN)

        # Testing invalid conversation status
        data["status"] = "INVALID_STATUS"
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.status, ConversationStatus.UNKNOWN)

    def test_is_in_running_status(self):
        conversation = ConversationDO(**self.sample_data)
        self.assertTrue(conversation.is_in_running_status())
        conversation.status = ConversationStatus.COMPLETED
        self.assertFalse(conversation.is_in_running_status())

    def test_message_conf(self):
        # Test initialization with a MessageConf object
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        self._assert_message_conf(conversation.message_conf)

        # Test initialization with a dictionary
        conf_dict = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO(message_conf=conf_dict, **self.sample_data)
        self._assert_message_conf(conversation.message_conf)

    def test_to_dict_with_message_conf(self):
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        data = conversation.to_dict()
        self.assertEqual(data["message_conf"]["temperature"], 0.9)
        self.assertEqual(data["message_conf"]["top_p"], 0.8)
        self.assertEqual(data["message_conf"]["top_k"], 5)

    def test_from_dict_with_message_conf(self):
        data = self.sample_data.copy()
        data["message_conf"] = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO.from_dict(data)
        self._assert_message_conf(conversation.message_conf)

    def test_update_with_message_conf(self):
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "message_conf": {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        }
        conversation.update(updated_data)
        self._assert_message_conf(conversation.message_conf)
import unittest
from src.utils import (
    is_image_link,
    extract_urls,
    extract_and_replace_url,
    replace_latex_format
)

# URLs and sandbox output shared by the fixtures below.
IMAGE_URL = 'http://tosv.byted.org/obj/codegraph/sandbox/temp/ws987dcaa4dba3400b91707f6825146f4e/1696533095597078758.png'
CSV_URL = 'http://tosv.byted.org/obj/codegraph/sandbox/fake.csv'
RUN_OUTPUT = 'Ran code\nSTDOUT:\n0.04886713118710795'

# LaTeX sample: the text before, inside and after the \[ ... \] delimiters.
LATEX_PREFIX = 'for people who earn $10,000 or $100,000 in each country:\n\n'
LATEX_BODY = (
    '\n\\begin{array}{|c|c|}\n\\hline\n\\textbf{native-country} & \\textbf{hours-per-week} \\\\\n\\hline\n'
    '?\\ & 45.55 \\\\\n Cambodia & 40.00 \\\\\n Canada & 45.64 \\\\\n... &... \\\\\n'
    ' United-States & 45.51 \\\\\n Vietnam & 39.20 \\\\\n Yugoslavia & 49.50 \\\\\n\\hline\n\\end{array}\n'
)
LATEX_SUFFIX = (
    '\n\n(Note: The country labeled "?" is likely a placeholder for countries with missing or unknown '
    'native-country values.)\n\nPlease note that this dataset contains non-binary values for the '
    'native-country column.'
)


class TestStringUtils(unittest.TestCase):
    """Tests for the URL and LaTeX helpers in ``src.utils``."""

    def setUp(self):
        self.image_links = [IMAGE_URL]

        self.raw_inputs = [
            RUN_OUTPUT,
            f'{RUN_OUTPUT}\nGenerated an image: {IMAGE_URL}',
            f'{RUN_OUTPUT}\n Add here on purpose Generated an image: {IMAGE_URL}',
            f'{RUN_OUTPUT}\nGenerated an image: {IMAGE_URL}\nGenerated files on server: {CSV_URL}',
        ]

        self.format_outputs = [
            f'```python\n{RUN_OUTPUT}```\n',
            f'```python\n{RUN_OUTPUT}\n```\n![The Image]({IMAGE_URL})',
            f'```python\n{RUN_OUTPUT}\n Add here on purpose ```\n![The Image]({IMAGE_URL})',
            f'```python\n{RUN_OUTPUT}\n\n```\n![The Image]({IMAGE_URL})\n[fake.csv]({CSV_URL})',
        ]

        self.extracted_urls = [
            [],
            [IMAGE_URL],
            [IMAGE_URL],
            [IMAGE_URL, CSV_URL],
        ]

        self.raw_latex_input = [LATEX_PREFIX + '\\[' + LATEX_BODY + '\\]' + LATEX_SUFFIX]
        self.raw_latex_output = [LATEX_PREFIX + '$$' + LATEX_BODY + '$$' + LATEX_SUFFIX]

    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run under ``python -O`` and failures show both values.

    def test_is_image_link(self):
        for link in self.image_links:
            with self.subTest(link=link):
                self.assertTrue(is_image_link(link))

    def test_extract_urls(self):
        for input_text, expected_matched_urls in zip(self.raw_inputs, self.extracted_urls):
            with self.subTest(input_text=input_text):
                self.assertEqual(list(extract_urls(input_text)), expected_matched_urls)

    def test_extract_and_replace(self):
        for input_text, expected_output_text in zip(self.raw_inputs, self.format_outputs):
            with self.subTest(input_text=input_text):
                self.assertEqual(extract_and_replace_url(input_text), expected_output_text)

    def test_replace_latex_format(self):
        for input_text, expected_output_text in zip(self.raw_latex_input, self.raw_latex_output):
            with self.subTest(input_text=input_text):
                self.assertEqual(replace_latex_format(input_text), expected_output_text)


if __name__ == '__main__':
    unittest.main()
import csv
import io
import os
import unittest

from src.utils import upload_files, clear_files, convert_delimiter_to_comma


class UploadedFile:
    """In-memory stand-in for an uploaded file (name, bytes, MIME type, size)."""

    def __init__(self, name, data, type):
        self.name = name
        self._data = data
        self.type = type
        self.size = len(data)

    def getvalue(self):
        return self._data

    def read(self):
        return self._data


class TestFileUtils(unittest.TestCase):
    """Tests for upload_files, clear_files and convert_delimiter_to_comma."""

    def setUp(self):
        # Sample uploads: two CSV variants and one plain text file.
        self.sample_csv_file_comma = UploadedFile(
            name='sample_comma.csv', data=b'col1,col2,col3\n1,2,3\n4,5,6', type='text/csv')
        self.sample_csv_file_semicolon = UploadedFile(
            name='sample_semicolon.csv', data=b'col1;col2;col3\n1;2;3\n4;5;6', type='text/csv')
        self.sample_txt_file = UploadedFile(
            name='sample.txt', data=b'Sample text file content', type='text/plain')

    def tearDown(self):
        # Remove anything the test wrote to the upload directory.
        clear_files("./tmp/upload_files/")

    @staticmethod
    def _convert(text):
        """Run convert_delimiter_to_comma on *text*; return (decoded text, flag)."""
        source_stream = io.BytesIO(text.encode())
        new_stream, converted = convert_delimiter_to_comma(source_stream)
        return new_stream.getvalue().decode(), converted

    def test_upload_non_csv_file(self):
        stored_paths = upload_files([self.sample_txt_file])
        self.assertEqual(len(stored_paths), 1)
        self.assertTrue(stored_paths[0].endswith('sample.txt'))

    def test_upload_csv_file_comma(self):
        stored_paths = upload_files([self.sample_csv_file_comma])
        self.assertEqual(len(stored_paths), 1)
        self.assertTrue(stored_paths[0].endswith('sample_comma.csv'))

    def test_upload_csv_file_semicolon(self):
        stored_paths = upload_files([self.sample_csv_file_semicolon])
        self.assertEqual(len(stored_paths), 1)
        self.assertTrue(stored_paths[0].endswith('sample_semicolon.csv'))

        # The stored file must have been rewritten with comma delimiters.
        with open(stored_paths[0], 'r', encoding='utf-8') as file:
            rows = list(csv.reader(file))
        self.assertEqual(rows, [['col1', 'col2', 'col3'], ['1', '2', '3'], ['4', '5', '6']])

    def test_clear_stored_files(self):
        upload_files([self.sample_txt_file, self.sample_csv_file_comma])
        clear_files("./tmp/upload_files/")
        self.assertEqual(len(os.listdir("./tmp/upload_files/")), 0)

    def test_comma_delimiter(self):
        content = "a,b,c\n1,2,3\n"
        decoded, converted = self._convert(content)
        self.assertEqual(converted, False)
        self.assertEqual(decoded, content)

    def test_semicolon_delimiter(self):
        decoded, converted = self._convert("a;b;c\n1;2;3\n")
        self.assertEqual(converted, True)
        self.assertEqual(decoded, "a,b,c\n1,2,3\n")

    def test_tab_delimiter(self):
        decoded, converted = self._convert("a\tb\tc\n1\t2\t3\n")
        self.assertEqual(converted, True)
        self.assertEqual(decoded, "a,b,c\n1,2,3\n")


if __name__ == '__main__':
    unittest.main()
import httpx
import asyncio
import aiohttp


async def fetch(url: str):
    """Return the response body of a GET request to *url*, using aiohttp."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()


async def test_httpx():
    """Check that httpx.AsyncClient can complete a GET request.

    The original printed the response and returned "" without asserting
    anything, so it could never fail on a bad response; tests should also
    not return a value.
    NOTE(review): this needs network access, and plain pytest does not run
    ``async def`` tests without an async plugin -- confirm how this module
    is meant to be executed.
    """
    url = 'https://www.example.org/'
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
    print(response)
    assert response.status_code == 200


if __name__ == "__main__":
    # Manual run: fetch a page with aiohttp and print its body.
    text = asyncio.run(fetch('http://www.example.com'))
    print(text)
from unittest.mock import MagicMock

import pytest

import src.schemas.sandbox_models as dm
from src.tools import AsyncPythonSandBoxTool, PythonSandBoxToolResponse

# Fenced snippet passed to tool.run in every run-code scenario.
FENCED_CODE = "```python\nabc\n```"

# Code expected back from input_handler once the fence is stripped.
# NOTE(review): the indentation of the continuation lines is assumed to be
# four spaces -- confirm against the original test file.
EXPECTED_CODE = (
    "import pandas as pd\n"
    "\n"
    "    # Load the data\n"
    "    df = pd.read_csv('/mnt/0020400390.csv')\n"
    "\n"
    "    # Checking the first few records\n"
    "    df.head()"
)


# Mock SandboxClient
@pytest.fixture
def mock_client():
    return MagicMock()


# PythonSandBoxTool wired to the mocked client
@pytest.fixture
def tool(mock_client):
    sandbox_tool = AsyncPythonSandBoxTool(name="Sandbox", description="Remote code sandbox")
    sandbox_tool.client = mock_client
    return sandbox_tool


def _succeeded_run(outputs=(), deleted=(), generated=()):
    """Build a successful, non-partial RunCodeOutput with the given contents."""
    run_result = dm.CodeRunResult(
        code_output_result=list(outputs),
        deleted_files=list(deleted),
        new_generated_files=list(generated),
    )
    return dm.RunCodeOutput(
        code=0,
        message="succeed",
        data=dm.CodeRunData(is_partial=False, result=run_result),
    )


def test_upload_file(tool, mock_client):
    mock_client.upload_file.return_value = dm.UploadOutput(code=0, message="succeed", data="/mnt/file.txt")
    response = tool.upload_file("/path/to/file.txt")
    assert isinstance(response, PythonSandBoxToolResponse)
    assert response.output_text == "Successfully uploaded file to server's path: /mnt/file.txt"


def test_download_file(tool, mock_client):
    mock_client.download_file.return_value = dm.DownloadSuccessOutput(file_name="file.txt", content="hello world")
    response = tool.download_file("file.txt")
    assert isinstance(response, PythonSandBoxToolResponse)
    assert response.output_text == "Successfully downloaded file: file.txt"
    assert response.raw_output.content == "hello world"


def test_extract_code(tool):
    # The same code, fenced with and without a language tag.
    tagged_fence = "```python\n    " + EXPECTED_CODE + "\n    ```"
    plain_fence = "```\n    " + EXPECTED_CODE + "\n    ```"

    assert tool.input_handler(tagged_fence) == EXPECTED_CODE
    assert tool.input_handler(plain_fence) == EXPECTED_CODE


def test_run_code_success(tool, mock_client):
    # code execute successful, stdout
    mock_client.run_code.return_value = _succeeded_run(
        outputs=[dm.CodeOutput(type="stdout", content="hello\n")])
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    assert response.output_text == "Ran code\nSTDOUT:\nhello\n\n"

    # code execute successful, image
    mock_client.run_code.return_value = _succeeded_run(
        outputs=[dm.CodeOutput(type="image", content="https://example.com/image.png")])
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    assert response.output_text == "Ran code\nGenerated an image: https://example.com/image.png\n"

    # code generated output file
    mock_client.run_code.return_value = _succeeded_run(generated=["/mnt/qr.jpg", "/mnt/qr2.jpg"])
    mock_client.download_file.return_value = dm.DownloadSuccessOutput(file_name="qr.jpg", content="hello world")
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    file_str = ",".join(response.raw_output.data.result.new_generated_files)
    assert response.output_text == f"Ran code\nGenerated files on server: {file_str}\n"

    # code deleted files
    mock_client.run_code.return_value = _succeeded_run(deleted=["/mnt/qr.jpg", "/mnt/qr2.jpg"])
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    assert response.output_text == "Ran code\nDeleted files from server: /mnt/qr.jpg,/mnt/qr2.jpg\n"

    # everything
    mock_client.run_code.return_value = _succeeded_run(
        outputs=[
            dm.CodeOutput(type="image", content="https://example.com/image.png"),
            dm.CodeOutput(type="stdout", content="Plotted and saved files"),
            dm.CodeOutput(type="stderr", content="Something is deprecated"),
        ],
        deleted=["/mnt/qr.jpg", "/mnt/qr2.jpg"],
        generated=["/mnt/plot.png", "/mnt/plot2.png"],
    )
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    generated_file_str = ",".join(response.raw_output.data.result.new_generated_files)
    expected_str = "".join([
        "Ran code\n",
        "Deleted files from server: /mnt/qr.jpg,/mnt/qr2.jpg\n",
        f"Generated files on server: {generated_file_str}\n",
        "Generated an image: https://example.com/image.png\n",
        "STDOUT:\n",
        "Plotted and saved files\n",
        "STDERR:\n",
        "Something is deprecated\n",
    ])
    assert response.output_text == expected_str


def test_run_code_timeout(tool, mock_client):
    # Server payload for a run that hit the timeout (partial result).
    server_return = {
        "code": 0,
        "data": {
            "is_partial": True,
            "result": {
                "code_output_result": [],
                "deleted_files": [],
                "new_generated_files": []
            }
        },
        "message": "the code doesn't finish in timeout value 3"
    }
    mock_client.run_code.return_value = dm.RunCodeOutput(**server_return)
    response = tool.run(FENCED_CODE)
    assert isinstance(response, PythonSandBoxToolResponse)
    expected_str = (
        "Ran code but was not fully successful\n"
        "What happened: the code doesn't finish in timeout value 3\n"
    )
    assert response.output_text == expected_str