g3eIL committed on
Commit
77320e4
·
verified ·
1 Parent(s): 891350c

Upload 80 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +131 -0
  2. Dockerfile +3 -0
  3. activities/activity_helpers.py +33 -0
  4. activities/api.py +93 -0
  5. activities/complete_chat.py +77 -0
  6. activities/eval.py +207 -0
  7. activities/local_demo.py +108 -0
  8. activities/local_test.py +87 -0
  9. activities/predict.py +41 -0
  10. activities/vllm_api_server.py +636 -0
  11. configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml +23 -0
  12. configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml +23 -0
  13. configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml +23 -0
  14. configs/agent_configs/react_agent_gpt4_async.yaml +23 -0
  15. configs/agent_configs/react_agent_llama_async.yaml +23 -0
  16. configs/agent_configs/react_agent_opt_async.yaml +23 -0
  17. configs/tool_configs/async_python_code_sandbox.yaml +7 -0
  18. configs/tool_configs/async_python_code_sandbox_docker.yaml +7 -0
  19. run.sh +3 -0
  20. run_demo.sh +5 -0
  21. run_local.sh +4 -0
  22. setup.py +40 -0
  23. src/infiagent/__init__.py +0 -0
  24. src/infiagent/agent/__init__.py +2 -0
  25. src/infiagent/agent/base_agent.py +337 -0
  26. src/infiagent/agent/react/__init__.py +4 -0
  27. src/infiagent/agent/react/async_react_agent.py +299 -0
  28. src/infiagent/conversation_sessions/__init__.py +1 -0
  29. src/infiagent/conversation_sessions/code_interpreter_session.py +87 -0
  30. src/infiagent/exceptions/__init__.py +0 -0
  31. src/infiagent/exceptions/exceptions.py +46 -0
  32. src/infiagent/llm/__init__.py +5 -0
  33. src/infiagent/llm/base_llm.py +36 -0
  34. src/infiagent/llm/client/__init__.py +0 -0
  35. src/infiagent/llm/client/azure_openai.py +346 -0
  36. src/infiagent/llm/client/llama.py +377 -0
  37. src/infiagent/llm/client/openai.py +306 -0
  38. src/infiagent/llm/client/opt.py +373 -0
  39. src/infiagent/prompt/__init__.py +3 -0
  40. src/infiagent/prompt/prompt_template.py +83 -0
  41. src/infiagent/prompt/simple_react_prompt.py +17 -0
  42. src/infiagent/prompt/zero_shot_react_prompt.py +36 -0
  43. src/infiagent/schemas/__init__.py +5 -0
  44. src/infiagent/schemas/agent_models.py +148 -0
  45. src/infiagent/schemas/base_models.py +0 -0
  46. src/infiagent/schemas/complete_models.py +236 -0
  47. src/infiagent/schemas/llm_models.py +91 -0
  48. src/infiagent/schemas/sandbox_models.py +69 -0
  49. src/infiagent/services/__init__.py +0 -0
  50. src/infiagent/services/chat_complete_service.py +196 -0
.gitignore ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ *.py,cover
49
+ .hypothesis/
50
+ .pytest_cache/
51
+
52
+ # Translations
53
+ *.mo
54
+ *.pot
55
+
56
+ # Django stuff:
57
+ *.log
58
+ local_settings.py
59
+ db.sqlite3
60
+ db.sqlite3-journal
61
+
62
+ # Flask stuff:
63
+ instance/
64
+ .webassets-cache
65
+
66
+ # Scrapy stuff:
67
+ .scrapy
68
+
69
+ # Sphinx documentation
70
+ docs/_build/
71
+ build/doctrees
72
+ build/html
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # pyenv
81
+ .python-version
82
+
83
+ # celery beat schedule file
84
+ celerybeat-schedule
85
+
86
+ # SageMath parsed files
87
+ *.sage.py
88
+
89
+ # Environments
90
+ .env
91
+ .venv
92
+ env/
93
+ venv/
94
+ ENV/
95
+ env.bak/
96
+ venv.bak/
97
+
98
+ # Spyder project settings
99
+ .spyderproject
100
+ .spyproject
101
+
102
+ # Rope project settings
103
+ .ropeproject
104
+
105
+ # mkdocs documentation
106
+ /site
107
+
108
+ # mypy
109
+ .mypy_cache/
110
+ .dmypy.json
111
+ dmypy.json
112
+
113
+ # Pyre type checker
114
+ .pyre/
115
+
116
+ # pytype static type analyzer
117
+ .pytype/
118
+
119
+ # Cython debug symbols
120
+ cython_debug/
121
+
122
+ # JetBrains PyCharm specific
123
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, GoLand, Rider and Android Studio
124
+ .idea/
125
+ *.iml
126
+
127
+ # User-specific stuff
128
+ *.swp
129
+ *~
130
+ .Session.vim
131
+ /.sass-cache
Dockerfile ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Sandbox image for executing user data-analysis code.
# NOTE(review): both the base image tag and the pip packages are unpinned,
# so builds are not reproducible — consider pinning exact versions.
FROM python:3

# Data-science packages made available to executed user code.
RUN pip install pandas numpy scikit-learn matplotlib seaborn
activities/activity_helpers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from sse_starlette import ServerSentEvent
4
+
5
+ from infiagent.schemas import ResponseBaseData
6
+
7
+
8
+ IGNORE_PING_COMMENT = {"comment": "IGNORE PING"}
9
+ DONE = "[DONE]"
10
+
11
+
12
async def async_sse_response_format(response_data_gen):
    """Wrap an async stream of payloads as ServerSentEvent objects.

    Each item from *response_data_gen* is serialized into the JSON envelope
    {"response": ..., "ResponseBase": ...}; the DONE sentinel is forwarded
    verbatim so clients can detect end-of-stream.
    """
    async for payload in response_data_gen:
        if payload == DONE:
            yield ServerSentEvent(data=DONE)
        else:
            envelope = {
                "response": payload,
                "ResponseBase": ResponseBaseData().dict()
            }
            yield ServerSentEvent(data=json.dumps(envelope, ensure_ascii=False))
23
+
24
+
25
def json_response_format(content):
    """Build the service's standard JSON response envelope around *content*."""
    return dict(response=content, ResponseBase=ResponseBaseData().dict())
30
+
31
+
32
def get_ignore_ping_comment():
    """Return a factory that produces the ping-suppressing SSE comment event."""
    def _make_ping_event():
        return ServerSentEvent(**IGNORE_PING_COMMENT)
    return _make_ping_event
activities/api.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import uuid
3
+
4
+ import uvloop
5
+ from dotenv import load_dotenv
6
+ from fastapi import FastAPI, HTTPException, Request
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from sse_starlette.sse import EventSourceResponse, ServerSentEvent
9
+ from starlette.responses import JSONResponse, Response
10
+
11
+ from .activity_helpers import DONE
12
+ from .complete_chat import complete_chat_router
13
+ from .predict import predict_router
14
+
15
+ try:
16
+ import infiagent
17
+ from infiagent.schemas import FailedResponseBaseData
18
+ from infiagent.utils import get_logger, init_logging, log_id_var
19
+ except ImportError:
20
+ print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
21
+ from ..schemas import FailedResponseBaseData
22
+ from ..utils import get_logger, init_logging, log_id_var
23
+
24
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
25
+
26
+ SSE_API_PATHS = ["/complete_sse"]
27
+ LOG_ID_HEADER_NAME = "X-Tt-Logid"
28
+
29
+
30
+ load_dotenv()
31
+ init_logging()
32
+ logger = get_logger()
33
+
34
+ app = FastAPI()
35
+ origins = ["*"]
36
+ app.add_middleware(
37
+ CORSMiddleware,
38
+ allow_origins=origins,
39
+ allow_credentials=True,
40
+ allow_methods=["*"],
41
+ allow_headers=["*"],
42
+ )
43
+ app.include_router(complete_chat_router)
44
+ app.include_router(predict_router)
45
+
46
+
47
@app.middleware("http")
async def log_id_middleware(request: Request, call_next):
    """Propagate (or mint) a per-request log id and echo it on the response."""
    # Prefer the caller-supplied header; otherwise mint a fresh UUID.
    incoming_id = request.headers.get(LOG_ID_HEADER_NAME) or str(uuid.uuid4())
    log_id_var.set(incoming_id)

    response: Response = await call_next(request)
    response.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
    return response
60
+
61
+
62
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: log, then return 500 JSON (or SSE DONE for stream APIs)."""
    error_msg = f"Failed to handle request. Internal Server error: {exc}"
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # NOTE(review): a single ServerSentEvent is handed to
        # EventSourceResponse rather than an iterable — confirm sse_starlette
        # accepts this form.
        return EventSourceResponse(ServerSentEvent(data=DONE))
    return JSONResponse(
        status_code=500,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict()
        }
    )
77
+
78
+
79
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Translate HTTPException into the service's JSON/SSE error envelope."""
    error_msg = f"Failed to handle request. Error: {exc.detail}"
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        return EventSourceResponse(ServerSentEvent(data=DONE))
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict()
        }
    )
activities/complete_chat.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Request, HTTPException
2
+ from pydantic import ValidationError
3
+ from sse_starlette import EventSourceResponse, ServerSentEvent
4
+
5
+ from .activity_helpers import async_sse_response_format, IGNORE_PING_COMMENT, json_response_format
6
+
7
+ try:
8
+ import infiagent
9
+ from infiagent.db.conversation_dao import ConversationDAO
10
+ from infiagent.schemas import ChatCompleteRequest
11
+ from infiagent.services.chat_complete_sse_service import chat_event_generator, chat_event_response
12
+ from infiagent.tools.code_sandbox.async_sandbox_client import AsyncSandboxClient
13
+ from infiagent.utils import get_logger
14
+ except ImportError:
15
+ print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
16
+ from ..db.conversation_dao import ConversationDAO
17
+ from ..schemas import ChatCompleteRequest
18
+ from ..services.chat_complete_sse_service import chat_event_generator, chat_event_response
19
+ from ..tools.code_sandbox.async_sandbox_client import AsyncSandboxClient
20
+ from ..utils import get_logger
21
+
22
+ complete_chat_router = APIRouter()
23
+ logger = get_logger()
24
+
25
+
26
@complete_chat_router.post("/complete_sse")
async def complete_sse(request: Request):
    """Stream chat-completion results to the client over server-sent events."""
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        raise HTTPException(status_code=400,
                            detail="Invalid input chat_request. Error: {}".format(str(e)))
    logger.info("Got chat request: {}".format(chat_request))

    return EventSourceResponse(
        async_sse_response_format(chat_event_generator(chat_request)),
        ping_message_factory=lambda: ServerSentEvent(**IGNORE_PING_COMMENT))
39
+
40
+
41
@complete_chat_router.post("/complete")
async def complete(request: Request):
    """Non-streaming chat-completion endpoint returning one JSON envelope."""
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        raise HTTPException(status_code=400,
                            detail="Invalid input chat_request. Error: {}".format(str(e)))
    logger.info("Got chat request: {}".format(chat_request))

    response_items = await chat_event_response(chat_request)
    return json_response_format(response_items)
55
+
56
+
57
@complete_chat_router.get("/heartbeat")
async def heartbeat(chat_id: str = None, session_id: str = None):
    """Keep a conversation's sandbox alive; accepts chat_id or session_id."""
    if not (chat_id or session_id):
        raise HTTPException(status_code=400, detail="Either chat_id or session_id must be provided.")

    input_chat_id = chat_id or session_id

    # An unknown conversation is not an error: nothing to keep alive yet.
    conversation = await ConversationDAO.get_conversation(input_chat_id)
    if not conversation:
        logger.info(f'Call heartbeat on a non-exist conversion, {input_chat_id}')
        return json_response_format("conversation is not created, skip")

    # A conversation without a sandbox cannot receive a heartbeat.
    if conversation.sandbox_id is None:
        logger.error(f'No sandbox id for heartbeat, chat id {input_chat_id}')
        raise HTTPException(status_code=404, detail=f'No sandbox id for heartbeat, chat id {input_chat_id}')

    # TODO Add exception handling logic here for heartbeat failed in sandbox side
    heartbeat_response = await AsyncSandboxClient(conversation.sandbox_id).heartbeat()
    logger.info(f"Heartbeat response {heartbeat_response}")

    return json_response_format("succeed")
activities/eval.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import argparse
4
+ import asyncio
5
+ import logging
6
+ import sys
7
+ import json
8
+ import io
9
+
10
+ import openai
11
+
12
+
13
+ import infiagent
14
+ from infiagent.utils import get_logger, upload_files, get_file_name_and_path
15
+ from infiagent.services.chat_complete_service import predict
16
+
17
+
18
+ logger = get_logger()
19
+
20
+
21
class UploadedFile(io.BytesIO):
    """In-memory stand-in for a web-framework upload, built from a local file.

    Exposes the attributes downstream code expects from an uploaded file —
    ``name`` (base filename), ``type`` (MIME type), ``size`` (byte length) —
    while behaving as a readable binary stream via io.BytesIO.
    """

    def __init__(self, path):
        # Load the whole file into memory; acceptable for evaluation-sized tables.
        with open(path, 'rb') as file:
            data = file.read()

        super().__init__(data)

        # Bug fix: use os.path.basename instead of path.split("/")[-1],
        # which broke on Windows-style separators.
        self.name = os.path.basename(path)
        self.type = 'application/octet-stream'  # generic binary MIME type
        self.size = len(data)

    def __repr__(self):
        # Bug fix: the original printed "MyUploadedFile", which did not
        # match the actual class name.
        return f"UploadedFile(name={self.name}, size={self.size}, type={self.type})"

    def __len__(self):
        return self.size
38
+
39
+ # # 使用例子
40
+ # file_path = "path/to/your/file"
41
+ # uploaded_file = MyUploadedFile(file_path)
42
+
43
+ # print(uploaded_file)
44
+
45
+
46
def _get_script_params():
    """Parse CLI options for the evaluation run.

    Options: --llm (model name), --api_key (OpenAI key), --config_path
    (agent config, defaults to the llama react config).
    Returns the parsed Namespace, or None when parsing fails.
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('--llm', help='LLM Model for demo',
                            required=False, type=str)
        parser.add_argument('--api_key', help='Open API token key.',
                            required=False, type=str)
        parser.add_argument('--config_path', help='Config path for demo',
                            default="configs/agent_configs/react_agent_llama_async.yaml",
                            required=False, type=str)
        return parser.parse_args()
    except Exception as e:
        logger.error("Failed to get script input arguments: {}".format(str(e)), exc_info=True)
        return None
+
69
+
70
def extract_questions_and_concepts(file_path):
    """Parse ``\\Question{...}`` / ``\\Concepts{...}`` pairs from a text file.

    Returns a list of dicts with keys ``question`` (stripped str) and
    ``concepts`` (list of stripped, comma-separated strings).
    """
    with open(file_path, 'r') as handle:
        text = handle.read()

    pair_pattern = r'\\Question{(.*?)}\s*\\Concepts{(.*?)}'
    return [
        {
            'question': raw_question.strip(),
            'concepts': [part.strip() for part in raw_concepts.split(',')],
        }
        for raw_question, raw_concepts in re.findall(pair_pattern, text, re.DOTALL)
    ]
90
+
91
def read_dicts_from_file(file_name):
    """Load a JSON-lines file.

    Each line of *file_name* must hold one JSON object; the parsed
    dictionaries are returned in file order.

    :param file_name: Name of the file to read from.
    :return: List of dictionaries.
    """
    with open(file_name, 'r') as handle:
        return [json.loads(raw.rstrip('\n')) for raw in handle]
106
+
107
def read_questions(file_path):
    """Load a JSON questions file and return the parsed object.

    The stray debug ``print(file_path)`` from the original was removed;
    use logging if tracing is needed.
    """
    with open(file_path) as handle:
        return json.load(handle)
113
+
114
def extract_data_from_folder(folder_path):
    """Collect questions from every ``*.questions`` file in *folder_path*.

    Returns a dict mapping file name (without extension) to the data
    parsed by ``read_questions``.
    """
    print(f'folder_path {folder_path}')
    collected = {}
    for entry in os.listdir(folder_path):
        # Only *.questions files carry evaluation questions.
        if not entry.endswith('.questions'):
            continue
        stem = os.path.splitext(entry)[0]
        collected[stem] = read_questions(os.path.join(folder_path, entry))
    return collected
127
+
128
+
129
async def main():
    """Run the evaluation loop over the dev question set and dump results.

    Reads questions from ./data/da-dev-questions.jsonl, runs the agent's
    predict() on each, and checkpoints results to results_<model>.json
    every 10 questions and once at the end.
    """
    extracted_data = read_dicts_from_file('./data/da-dev-questions.jsonl')
    args = _get_script_params()

    model_name = getattr(args, "llm", None)
    open_ai_key = getattr(args, "api_key", None)

    # Bug fix: '--llm' is optional, so model_name may be None; the original
    # crashed with TypeError on the membership test.
    if model_name and "OPEN_AI" in model_name:
        logger.info("setup open ai ")
        if os.environ.get("OPENAI_API_KEY") is None:
            if open_ai_key:
                openai.api_key = open_ai_key
                os.environ["OPENAI_API_KEY"] = open_ai_key
            else:
                raise ValueError("OPENAI_API_KEY is None, please provide open ai key to use open ai model. Adding "
                                 "'--api_key' to set it up")

        # Silence the 'openai' logger's INFO chatter; keep warnings and above.
        logging.getLogger('openai').setLevel(logging.WARNING)
    else:
        logger.info("use local model ")

    table_path = 'data/da-dev-tables'
    results = []
    output_path = 'results_{}.json'.format(model_name)

    # enumerate replaces the original's hand-rolled `i` counter.
    for i, q in enumerate(extracted_data, start=1):
        input_text = q['question']
        concepts = q['concepts']
        file_path = os.path.join(table_path, q['file_name'])
        constraints = q['constraints']
        answer_format = q['format']  # renamed: `format` shadowed the builtin

        print(f'input_text: {input_text}')
        print(f'concepts: {concepts}')
        print(f'file_path: {file_path}')

        uploaded_file = UploadedFile(file_path)
        print(uploaded_file)

        prompt = f"Question: {input_text}\n{constraints}\n"

        response = await predict(
            prompt=prompt,
            model_name=model_name,
            config_path=args.config_path,
            uploaded_files=[uploaded_file]
        )

        results.append({
            'id': q['id'],
            'input_text': prompt,
            'concepts': concepts,
            'file_path': file_path,
            'response': response,
            'format': answer_format
        })
        print(f"response: {response}")

        # Periodic checkpoint so partial results survive a crash.
        if i % 10 == 0:
            with open(output_path, 'w') as outfile:
                json.dump(results, outfile, indent=4)

    with open(output_path, 'w') as outfile:
        json.dump(results, outfile, indent=4)


if __name__ == '__main__':
    asyncio.run(main())
206
+
207
+
activities/local_demo.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ import streamlit as st # type: ignore
8
+ import uvloop
9
+ import openai
10
+
11
try:
    import infiagent
    from infiagent.utils import get_logger, upload_files
    from infiagent.services.chat_complete_service import predict
except ImportError as import_err:
    # Bug fix: the original did `raise ("...")`, which raises a TypeError
    # because a str is not an exception. Raise a proper ImportError instead,
    # chained to the original failure.
    raise ImportError(
        "import infiagent failed, please install infiagent by 'pip install -e .' in the pipeline directory of ADA-Agent"
    ) from import_err

logger = get_logger()

# uvloop replaces the default asyncio event loop for better throughput.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
22
+
23
+
24
def _get_script_params():
    """Parse CLI options for the demo (--llm, --api_key, --config_path).

    Returns the parsed Namespace, or None when parsing fails.
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('--llm', help='LLM Model for demo',
                            required=False, type=str)
        parser.add_argument('--api_key', help='Open API token key.',
                            required=False, type=str)
        parser.add_argument('--config_path', help='Config path for demo',
                            required=False, type=str)
        return parser.parse_args()
    except Exception as e:
        logger.error("Failed to get script input arguments: {}".format(str(e)), exc_info=True)
        return None
45
+
46
+
47
async def main():
    """Streamlit demo: collect a prompt plus files and run the agent once."""
    args = _get_script_params()

    model_name = getattr(args, "llm", None)
    open_ai_key = getattr(args, "api_key", None)
    config_path = getattr(args, "config_path", None)

    # Bug fix: '--llm' is optional, so model_name may be None; the original
    # crashed with TypeError on the membership test.
    if model_name and "OPEN_AI" in model_name:
        logger.info("setup open ai ")
        if os.environ.get("OPENAI_API_KEY") is None:
            if open_ai_key:
                openai.api_key = open_ai_key
                os.environ["OPENAI_API_KEY"] = open_ai_key
            else:
                # Typo fix in the error message: 'opekn ai' -> 'open ai'.
                raise ValueError(
                    "OPENAI_API_KEY is None, please provide open ai key to use open ai model. Adding '--api_key' to set it up")

        # Silence the 'openai' logger's INFO chatter; keep warnings and above.
        logging.getLogger('openai').setLevel(logging.WARNING)
    else:
        logger.info("use local model ")

    st.set_page_config(layout="centered")
    st.title("InfiAgent Code Interpreter Demo 🚀")

    # Initialize session state on first run.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # UI components
    input_text = st.text_area("Write your prompt")
    uploaded_files = st.file_uploader("Upload your files", accept_multiple_files=True)
    button_pressed = st.button("Run code interpreter", use_container_width=True)

    if button_pressed and input_text != "":
        st.session_state.chat_history.append({"role": "user", "message": input_text})

        response = await predict(
            prompt=input_text,
            model_name=model_name,
            config_path=config_path,
            uploaded_files=uploaded_files,
        )

        st.session_state.chat_history.append({"role": "assistant", "message": response})

    # Replay the running conversation.
    for chat in st.session_state.chat_history:
        with st.chat_message(chat["role"]):
            st.write(chat["message"])


if __name__ == "__main__":
    asyncio.run(main())
activities/local_test.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from fastapi import FastAPI, HTTPException, Request
4
+ from pydantic import ValidationError
5
+ from sse_starlette import EventSourceResponse
6
+
7
+ from .activity_helpers import (
8
+ async_sse_response_format,
9
+ get_ignore_ping_comment,
10
+ json_response_format,
11
+ )
12
+
13
+
14
+ try:
15
+ import infiagent
16
+ from infiagent.schemas import ChatCompleteRequest
17
+ from infiagent.services.complete_local_test import (
18
+ chat_local_event,
19
+ chat_local_event_generator,
20
+ )
21
+ from infiagent.utils import get_logger
22
+ except ImportError:
23
+ print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
24
+ from ..schemas import ChatCompleteRequest
25
+ from ..services.complete_local_test import (
26
+ chat_local_event,
27
+ chat_local_event_generator,
28
+ )
29
+ from ..utils import get_logger
30
+
31
+ logger = get_logger()
32
+ local_app = FastAPI()
33
+
34
+
35
@local_app.post("/local_sse_test")
async def complete_sse(request: Request):
    """Local SSE smoke-test endpoint mirroring /complete_sse."""
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        raise HTTPException(status_code=500,
                            detail="Invalid input chat_request. Error: {}".format(str(e)))
    logger.info("Got chat request: {}".format(chat_request))

    return EventSourceResponse(async_sse_response_format(chat_local_event_generator(chat_request)),
                               ping_message_factory=get_ignore_ping_comment())
48
+
49
+
50
@local_app.post("/local_json_test")
async def complete_json(request: Request):
    """Local JSON smoke-test endpoint mirroring /complete."""
    raw_body = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(raw_body)
    except ValidationError as e:
        raise HTTPException(status_code=500,
                            detail="Invalid input chat_request. Error: {}".format(str(e)))
    logger.info("Got chat request: {}".format(chat_request))

    response_items = await chat_local_event(chat_request)
    return json_response_format(response_items)
64
+
65
+
66
@local_app.post("/exception_test")
async def complete_exception_sse(request: Request):
    """Exception-path SSE test endpoint (no ping factory on purpose).

    Bug fix: this handler was also named ``complete_json``, silently
    shadowing the handler defined just above it in the module namespace
    (flake8 F811). FastAPI routes by path, so renaming it is safe for
    all callers.
    """
    body_str = await request.body()

    try:
        chat_request = ChatCompleteRequest.parse_raw(body_str)
        logger.info("Got chat request: {}".format(chat_request))
    except ValidationError as e:
        error_msg = "Invalid input chat_request. Error: {}".format(str(e))
        raise HTTPException(status_code=500, detail=error_msg)
    return EventSourceResponse(async_sse_response_format(chat_local_event_generator(chat_request)))
77
+
78
+
79
async def exception_test(request: Request):
    """Async generator used to exercise error handling in local tests.

    Raises ValueError when the request body carries a truthy ``exception``
    field; otherwise yields a success payload.
    """
    body_str = await request.body()
    json_val = json.loads(body_str)
    exception_type = json_val.get("exception", None)

    if exception_type:
        # Typo fix in the error message: 'triggerd' -> 'triggered'.
        raise ValueError("Error triggered!")
    else:
        # NOTE(review): this yields an iterator object, not the string
        # "Success" — confirm whether `yield "Success"` was intended.
        yield iter(["Success"])
activities/predict.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, File, Form, UploadFile
2
+ from typing import List, Optional
3
+
4
+ try:
5
+ import infiagent
6
+ from infiagent.services.chat_complete_service import predict
7
+ except ImportError:
8
+ print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
9
+ from ..services.chat_complete_service import predict
10
+
11
+ predict_router = APIRouter()
12
+
13
+
14
@predict_router.post("/predict")
async def chat_predict(
        prompt: str = Form(...),
        model_name: str = Form(...),
        psm: Optional[str] = Form(None),
        dc: Optional[str] = Form(None),
        temperature: Optional[str] = Form(None),
        top_p: Optional[str] = Form(None),
        top_k: Optional[str] = Form(None),
        files: List[UploadFile] = File(...)
):
    """Form-based prediction endpoint; optional tuning fields are forwarded
    to predict() as keyword arguments only when they were supplied."""
    kwargs = {}
    for key, value in (('psm', psm), ('dc', dc)):
        if value:
            kwargs[key] = value
    # NOTE(review): top_k is cast to float like the other sampler knobs,
    # though samplers usually expect an int — confirm downstream.
    for key, value in (('temperature', temperature), ('top_p', top_p), ('top_k', top_k)):
        if value:
            kwargs[key] = float(value)

    response = await predict(prompt, model_name, files, **kwargs)

    return {
        "answer": response
    }
activities/vllm_api_server.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from
2
+ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
3
+
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ from http import HTTPStatus
9
+ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
10
+
11
+ import fastapi
12
+ import uvicorn
13
+ from fastapi import Request
14
+ from fastapi.exceptions import RequestValidationError
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.responses import JSONResponse, StreamingResponse, Response
17
+ from packaging import version
18
+
19
+ from vllm.engine.arg_utils import AsyncEngineArgs
20
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
21
+ from vllm.entrypoints.openai.protocol import (
22
+ CompletionRequest, CompletionResponse, CompletionResponseChoice,
23
+ CompletionResponseStreamChoice, CompletionStreamResponse,
24
+ ChatCompletionRequest, ChatCompletionResponse,
25
+ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
26
+ ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
27
+ LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
28
+ from vllm.logger import init_logger
29
+ from vllm.outputs import RequestOutput
30
+ from vllm.sampling_params import SamplingParams
31
+ from vllm.transformers_utils.tokenizer import get_tokenizer
32
+ from vllm.utils import random_uuid
33
+
34
+ try:
35
+ import fastchat
36
+ from fastchat.conversation import Conversation, SeparatorStyle
37
+ from fastchat.model.model_adapter import get_conversation_template
38
+ _fastchat_available = True
39
+ except ImportError:
40
+ _fastchat_available = False
41
+
42
+ TIMEOUT_KEEP_ALIVE = 5 # seconds
43
+
44
+ logger = init_logger(__name__)
45
+ served_model = None
46
+ app = fastapi.FastAPI()
47
+ engine = None
48
+
49
+
50
def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    """Build an OpenAI-style error payload with the given HTTP status."""
    body = ErrorResponse(message=message, type="invalid_request_error")
    return JSONResponse(body.dict(), status_code=status_code.value)
55
+
56
+
57
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    """Map FastAPI request-validation failures to a 400 error payload."""
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
60
+
61
+
62
async def check_model(request) -> Optional[JSONResponse]:
    """Return None when the requested model is the served one, else a 404 error."""
    if request.model == served_model:
        return None
    return create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
70
+
71
+
72
async def get_gen_prompt(request) -> str:
    """Render the request's messages into a model prompt via fastchat.

    Raises ModuleNotFoundError when fastchat is absent and ImportError when
    the installed version is older than 0.2.23.
    """
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "fastchat is not installed. Please install fastchat to use "
            "the chat completion and conversation APIs: `$ pip install fschat`"
        )
    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
        raise ImportError(
            f"fastchat version is low. Current version: {fastchat.__version__} "
            "Please upgrade fastchat to use: `$ pip install -U fschat`")

    template = get_conversation_template(request.model)
    # Rebuild the conversation so appending messages below cannot mutate
    # the cached template object (messages list is copied).
    conv = Conversation(
        name=template.name,
        system_template=template.system_template,
        system_message=template.system_message,
        roles=template.roles,
        messages=list(template.messages),
        offset=template.offset,
        sep_style=SeparatorStyle(template.sep_style),
        sep=template.sep,
        sep2=template.sep2,
        stop_str=template.stop_str,
        stop_token_ids=template.stop_token_ids,
    )

    # A bare string is used verbatim as the prompt.
    if isinstance(request.messages, str):
        return request.messages

    role_slots = {"user": conv.roles[0], "assistant": conv.roles[1]}
    for message in request.messages:
        msg_role = message["role"]
        if msg_role == "system":
            conv.system_message = message["content"]
        elif msg_role in role_slots:
            conv.append_message(role_slots[msg_role], message["content"])
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    # A blank assistant turn marks where generation should continue.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()
117
+
118
+
119
async def check_length(
    request: Union[ChatCompletionRequest, CompletionRequest],
    prompt: Optional[str] = None,
    prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
    """Tokenize the prompt and validate it against the model's context window.

    Exactly one of *prompt* / *prompt_ids* must be supplied. Side effect:
    when the client left ``request.max_tokens`` unset, it is filled in with
    the remaining context budget. Returns ``(input_ids, error_or_None)``.
    """
    assert (not (prompt is None and prompt_ids is None)
            and not (prompt is not None and prompt_ids is not None)
            ), "Either prompt or prompt_ids should be provided."
    input_ids = prompt_ids if prompt_ids is not None else tokenizer(prompt).input_ids
    token_num = len(input_ids)

    if request.max_tokens is None:
        request.max_tokens = max_model_len - token_num
    if token_num + request.max_tokens <= max_model_len:
        return input_ids, None
    return input_ids, create_error_response(
        HTTPStatus.BAD_REQUEST,
        f"This model's maximum context length is {max_model_len} tokens. "
        f"However, you requested {request.max_tokens + token_num} tokens "
        f"({token_num} in the messages, "
        f"{request.max_tokens} in the completion). "
        f"Please reduce the length of the messages or completion.",
    )
146
+
147
+
148
+ @app.get("/health")
149
+ async def health() -> Response:
150
+ """Health check."""
151
+ return Response(status_code=200)
152
+
153
+
154
+ @app.get("/v1/models")
155
+ async def show_available_models():
156
+ """Show available models. Right now we only have one model."""
157
+ model_cards = [
158
+ ModelCard(id=served_model,
159
+ root=served_model,
160
+ permission=[ModelPermission()])
161
+ ]
162
+ return ModelList(data=model_cards)
163
+
164
+
165
def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Convert per-token logprob dicts into an OpenAI-style LogProbs object."""
    result = LogProbs()
    prev_token_len = 0
    for token_id, candidates in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        result.tokens.append(token)
        result.token_logprobs.append(candidates[token_id])
        # text_offset holds the cumulative character position of each token.
        if result.text_offset:
            result.text_offset.append(result.text_offset[-1] + prev_token_len)
        else:
            result.text_offset.append(initial_text_offset)
        prev_token_len = len(token)

        result.top_logprobs.append({
            tokenizer.convert_ids_to_tokens(tid): lp
            for tid, lp in candidates.items()
        })
    return result
187
+
188
+
189
+ @app.post("/v1/chat/completions")
190
+ async def create_chat_completion(request: ChatCompletionRequest,
191
+ raw_request: Request):
192
+ """Completion API similar to OpenAI's API.
193
+
194
+ See https://platform.openai.com/docs/api-reference/chat/create
195
+ for the API specification. This API mimics the OpenAI ChatCompletion API.
196
+
197
+ NOTE: Currently we do not support the following features:
198
+ - function_call (Users should implement this by themselves)
199
+ - logit_bias (to be supported by vLLM engine)
200
+ """
201
+ logger.info(f"Received chat completion request: {request}")
202
+
203
+ error_check_ret = await check_model(request)
204
+ if error_check_ret is not None:
205
+ return error_check_ret
206
+
207
+ if request.logit_bias is not None and len(request.logit_bias) > 0:
208
+ # TODO: support logit_bias in vLLM engine.
209
+ return create_error_response(HTTPStatus.BAD_REQUEST,
210
+ "logit_bias is not currently supported")
211
+
212
+ prompt = await get_gen_prompt(request)
213
+ token_ids, error_check_ret = await check_length(request, prompt=prompt)
214
+ if error_check_ret is not None:
215
+ return error_check_ret
216
+
217
+ model_name = request.model
218
+ request_id = f"cmpl-{random_uuid()}"
219
+ created_time = int(time.monotonic())
220
+ try:
221
+ # spaces_between_special_tokens = request.spaces_between_special_tokens
222
+ sampling_params = SamplingParams(
223
+ n=request.n,
224
+ presence_penalty=request.presence_penalty,
225
+ frequency_penalty=request.frequency_penalty,
226
+ temperature=request.temperature,
227
+ top_p=request.top_p,
228
+ stop=request.stop,
229
+ stop_token_ids=request.stop_token_ids,
230
+ max_tokens=request.max_tokens,
231
+ best_of=request.best_of,
232
+ top_k=request.top_k,
233
+ ignore_eos=request.ignore_eos,
234
+ use_beam_search=request.use_beam_search,
235
+ skip_special_tokens=request.skip_special_tokens,
236
+ # spaces_between_special_tokens=spaces_between_special_tokens,
237
+ )
238
+ except ValueError as e:
239
+ return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
240
+
241
+ result_generator = engine.generate(prompt, sampling_params, request_id,
242
+ token_ids)
243
+
244
+ def create_stream_response_json(
245
+ index: int,
246
+ text: str,
247
+ finish_reason: Optional[str] = None,
248
+ ) -> str:
249
+ choice_data = ChatCompletionResponseStreamChoice(
250
+ index=index,
251
+ delta=DeltaMessage(content=text),
252
+ finish_reason=finish_reason,
253
+ )
254
+ response = ChatCompletionStreamResponse(
255
+ id=request_id,
256
+ created=created_time,
257
+ model=model_name,
258
+ choices=[choice_data],
259
+ )
260
+ response_json = response.json(ensure_ascii=False)
261
+
262
+ return response_json
263
+
264
+ async def completion_stream_generator() -> AsyncGenerator[str, None]:
265
+ # First chunk with role
266
+ for i in range(request.n):
267
+ choice_data = ChatCompletionResponseStreamChoice(
268
+ index=i,
269
+ delta=DeltaMessage(role="assistant"),
270
+ finish_reason=None,
271
+ )
272
+ chunk = ChatCompletionStreamResponse(id=request_id,
273
+ choices=[choice_data],
274
+ model=model_name)
275
+ data = chunk.json(exclude_unset=True, ensure_ascii=False)
276
+ yield f"data: {data}\n\n"
277
+
278
+ previous_texts = [""] * request.n
279
+ previous_num_tokens = [0] * request.n
280
+ async for res in result_generator:
281
+ res: RequestOutput
282
+ for output in res.outputs:
283
+ i = output.index
284
+ delta_text = output.text[len(previous_texts[i]):]
285
+ previous_texts[i] = output.text
286
+ previous_num_tokens[i] = len(output.token_ids)
287
+ response_json = create_stream_response_json(
288
+ index=i,
289
+ text=delta_text,
290
+ )
291
+ yield f"data: {response_json}\n\n"
292
+ if output.finish_reason is not None:
293
+ response_json = create_stream_response_json(
294
+ index=i,
295
+ text="",
296
+ finish_reason=output.finish_reason,
297
+ )
298
+ yield f"data: {response_json}\n\n"
299
+ yield "data: [DONE]\n\n"
300
+
301
+ # Streaming response
302
+ if request.stream:
303
+ return StreamingResponse(completion_stream_generator(),
304
+ media_type="text/event-stream")
305
+
306
+ # Non-streaming response
307
+ final_res: RequestOutput = None
308
+ async for res in result_generator:
309
+ if await raw_request.is_disconnected():
310
+ # Abort the request if the client disconnects.
311
+ await engine.abort(request_id)
312
+ return create_error_response(HTTPStatus.BAD_REQUEST,
313
+ "Client disconnected")
314
+ final_res = res
315
+ assert final_res is not None
316
+ choices = []
317
+ for output in final_res.outputs:
318
+ choice_data = ChatCompletionResponseChoice(
319
+ index=output.index,
320
+ message=ChatMessage(role="assistant", content=output.text),
321
+ finish_reason=output.finish_reason,
322
+ )
323
+ choices.append(choice_data)
324
+
325
+ num_prompt_tokens = len(final_res.prompt_token_ids)
326
+ num_generated_tokens = sum(
327
+ len(output.token_ids) for output in final_res.outputs)
328
+ usage = UsageInfo(
329
+ prompt_tokens=num_prompt_tokens,
330
+ completion_tokens=num_generated_tokens,
331
+ total_tokens=num_prompt_tokens + num_generated_tokens,
332
+ )
333
+ response = ChatCompletionResponse(
334
+ id=request_id,
335
+ created=created_time,
336
+ model=model_name,
337
+ choices=choices,
338
+ usage=usage,
339
+ )
340
+
341
+ if request.stream:
342
+ # When user requests streaming but we don't stream, we still need to
343
+ # return a streaming response with a single event.
344
+ response_json = response.json(ensure_ascii=False)
345
+
346
+ async def fake_stream_generator() -> AsyncGenerator[str, None]:
347
+ yield f"data: {response_json}\n\n"
348
+ yield "data: [DONE]\n\n"
349
+
350
+ return StreamingResponse(fake_stream_generator(),
351
+ media_type="text/event-stream")
352
+
353
+ return response
354
+
355
+
356
+ @app.post("/v1/completions")
357
+ async def create_completion(request: CompletionRequest, raw_request: Request):
358
+ """Completion API similar to OpenAI's API.
359
+
360
+ See https://platform.openai.com/docs/api-reference/completions/create
361
+ for the API specification. This API mimics the OpenAI Completion API.
362
+
363
+ NOTE: Currently we do not support the following features:
364
+ - echo (since the vLLM engine does not currently support
365
+ getting the logprobs of prompt tokens)
366
+ - suffix (the language models we currently support do not support
367
+ suffix)
368
+ - logit_bias (to be supported by vLLM engine)
369
+ """
370
+ logger.info(f"Received completion request: {request}")
371
+
372
+ error_check_ret = await check_model(request)
373
+ if error_check_ret is not None:
374
+ return error_check_ret
375
+
376
+ if request.echo:
377
+ # We do not support echo since the vLLM engine does not
378
+ # currently support getting the logprobs of prompt tokens.
379
+ return create_error_response(HTTPStatus.BAD_REQUEST,
380
+ "echo is not currently supported")
381
+
382
+ if request.suffix is not None:
383
+ # The language models we currently support do not support suffix.
384
+ return create_error_response(HTTPStatus.BAD_REQUEST,
385
+ "suffix is not currently supported")
386
+
387
+ if request.logit_bias is not None and len(request.logit_bias) > 0:
388
+ # TODO: support logit_bias in vLLM engine.
389
+ return create_error_response(HTTPStatus.BAD_REQUEST,
390
+ "logit_bias is not currently supported")
391
+
392
+ model_name = request.model
393
+ request_id = f"cmpl-{random_uuid()}"
394
+
395
+ use_token_ids = False
396
+ if isinstance(request.prompt, list):
397
+ if len(request.prompt) == 0:
398
+ return create_error_response(HTTPStatus.BAD_REQUEST,
399
+ "please provide at least one prompt")
400
+ first_element = request.prompt[0]
401
+ if isinstance(first_element, int):
402
+ use_token_ids = True
403
+ prompt = request.prompt
404
+ elif isinstance(first_element, (str, list)):
405
+ # TODO: handles multiple prompt case in list[list[int]]
406
+ if len(request.prompt) > 1:
407
+ return create_error_response(
408
+ HTTPStatus.BAD_REQUEST,
409
+ "multiple prompts in a batch is not currently supported")
410
+ use_token_ids = not isinstance(first_element, str)
411
+ prompt = request.prompt[0]
412
+ else:
413
+ prompt = request.prompt
414
+
415
+ if use_token_ids:
416
+ _, error_check_ret = await check_length(request, prompt_ids=prompt)
417
+ else:
418
+ token_ids, error_check_ret = await check_length(request, prompt=prompt)
419
+ if error_check_ret is not None:
420
+ return error_check_ret
421
+
422
+ created_time = int(time.monotonic())
423
+ try:
424
+ # spaces_between_special_tokens = request.spaces_between_special_tokens
425
+ sampling_params = SamplingParams(
426
+ n=request.n,
427
+ best_of=request.best_of,
428
+ presence_penalty=request.presence_penalty,
429
+ frequency_penalty=request.frequency_penalty,
430
+ temperature=request.temperature,
431
+ top_p=request.top_p,
432
+ top_k=request.top_k,
433
+ stop=request.stop,
434
+ stop_token_ids=request.stop_token_ids,
435
+ ignore_eos=request.ignore_eos,
436
+ max_tokens=request.max_tokens,
437
+ logprobs=request.logprobs,
438
+ use_beam_search=request.use_beam_search,
439
+ skip_special_tokens=request.skip_special_tokens,
440
+ # spaces_between_special_tokens=spaces_between_special_tokens,
441
+ )
442
+ except ValueError as e:
443
+ return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
444
+
445
+ if use_token_ids:
446
+ result_generator = engine.generate(None,
447
+ sampling_params,
448
+ request_id,
449
+ prompt_token_ids=prompt)
450
+ else:
451
+ result_generator = engine.generate(prompt, sampling_params, request_id,
452
+ token_ids)
453
+
454
+ # Similar to the OpenAI API, when n != best_of, we do not stream the
455
+ # results. In addition, we do not stream the results when use beam search.
456
+ stream = (request.stream
457
+ and (request.best_of is None or request.n == request.best_of)
458
+ and not request.use_beam_search)
459
+
460
+ def create_stream_response_json(
461
+ index: int,
462
+ text: str,
463
+ logprobs: Optional[LogProbs] = None,
464
+ finish_reason: Optional[str] = None,
465
+ ) -> str:
466
+ choice_data = CompletionResponseStreamChoice(
467
+ index=index,
468
+ text=text,
469
+ logprobs=logprobs,
470
+ finish_reason=finish_reason,
471
+ )
472
+ response = CompletionStreamResponse(
473
+ id=request_id,
474
+ created=created_time,
475
+ model=model_name,
476
+ choices=[choice_data],
477
+ )
478
+ response_json = response.json(ensure_ascii=False)
479
+
480
+ return response_json
481
+
482
+ async def completion_stream_generator() -> AsyncGenerator[str, None]:
483
+ previous_texts = [""] * request.n
484
+ previous_num_tokens = [0] * request.n
485
+ async for res in result_generator:
486
+ res: RequestOutput
487
+ for output in res.outputs:
488
+ i = output.index
489
+ delta_text = output.text[len(previous_texts[i]):]
490
+ if request.logprobs is not None:
491
+ logprobs = create_logprobs(
492
+ output.token_ids[previous_num_tokens[i]:],
493
+ output.logprobs[previous_num_tokens[i]:],
494
+ len(previous_texts[i]))
495
+ else:
496
+ logprobs = None
497
+ previous_texts[i] = output.text
498
+ previous_num_tokens[i] = len(output.token_ids)
499
+ response_json = create_stream_response_json(
500
+ index=i,
501
+ text=delta_text,
502
+ logprobs=logprobs,
503
+ )
504
+ yield f"data: {response_json}\n\n"
505
+ if output.finish_reason is not None:
506
+ logprobs = (LogProbs()
507
+ if request.logprobs is not None else None)
508
+ response_json = create_stream_response_json(
509
+ index=i,
510
+ text="",
511
+ logprobs=logprobs,
512
+ finish_reason=output.finish_reason,
513
+ )
514
+ yield f"data: {response_json}\n\n"
515
+ yield "data: [DONE]\n\n"
516
+
517
+ # Streaming response
518
+ if stream:
519
+ return StreamingResponse(completion_stream_generator(),
520
+ media_type="text/event-stream")
521
+
522
+ # Non-streaming response
523
+ final_res: RequestOutput = None
524
+ async for res in result_generator:
525
+ if await raw_request.is_disconnected():
526
+ # Abort the request if the client disconnects.
527
+ await engine.abort(request_id)
528
+ return create_error_response(HTTPStatus.BAD_REQUEST,
529
+ "Client disconnected")
530
+ final_res = res
531
+ assert final_res is not None
532
+ choices = []
533
+ for output in final_res.outputs:
534
+ if request.logprobs is not None:
535
+ logprobs = create_logprobs(output.token_ids, output.logprobs)
536
+ else:
537
+ logprobs = None
538
+ choice_data = CompletionResponseChoice(
539
+ index=output.index,
540
+ text=output.text,
541
+ logprobs=logprobs,
542
+ finish_reason=output.finish_reason,
543
+ )
544
+ choices.append(choice_data)
545
+
546
+ num_prompt_tokens = len(final_res.prompt_token_ids)
547
+ num_generated_tokens = sum(
548
+ len(output.token_ids) for output in final_res.outputs)
549
+ usage = UsageInfo(
550
+ prompt_tokens=num_prompt_tokens,
551
+ completion_tokens=num_generated_tokens,
552
+ total_tokens=num_prompt_tokens + num_generated_tokens,
553
+ )
554
+ response = CompletionResponse(
555
+ id=request_id,
556
+ created=created_time,
557
+ model=model_name,
558
+ choices=choices,
559
+ usage=usage,
560
+ )
561
+
562
+ if request.stream:
563
+ # When user requests streaming but we don't stream, we still need to
564
+ # return a streaming response with a single event.
565
+ response_json = response.json(ensure_ascii=False)
566
+
567
+ async def fake_stream_generator() -> AsyncGenerator[str, None]:
568
+ yield f"data: {response_json}\n\n"
569
+ yield "data: [DONE]\n\n"
570
+
571
+ return StreamingResponse(fake_stream_generator(),
572
+ media_type="text/event-stream")
573
+
574
+ return response
575
+
576
+
577
+ if __name__ == "__main__":
578
+ parser = argparse.ArgumentParser(
579
+ description="vLLM OpenAI-Compatible RESTful API server.")
580
+ parser.add_argument("--host", type=str, default=None, help="host name")
581
+ parser.add_argument("--port", type=int, default=8000, help="port number")
582
+ parser.add_argument("--allow-credentials",
583
+ action="store_true",
584
+ help="allow credentials")
585
+ parser.add_argument("--allowed-origins",
586
+ type=json.loads,
587
+ default=["*"],
588
+ help="allowed origins")
589
+ parser.add_argument("--allowed-methods",
590
+ type=json.loads,
591
+ default=["*"],
592
+ help="allowed methods")
593
+ parser.add_argument("--allowed-headers",
594
+ type=json.loads,
595
+ default=["*"],
596
+ help="allowed headers")
597
+ parser.add_argument("--served-model-name",
598
+ type=str,
599
+ default=None,
600
+ help="The model name used in the API. If not "
601
+ "specified, the model name will be the same as "
602
+ "the huggingface name.")
603
+
604
+ parser = AsyncEngineArgs.add_cli_args(parser)
605
+ args = parser.parse_args()
606
+
607
+ app.add_middleware(
608
+ CORSMiddleware,
609
+ allow_origins=args.allowed_origins,
610
+ allow_credentials=args.allow_credentials,
611
+ allow_methods=args.allowed_methods,
612
+ allow_headers=args.allowed_headers,
613
+ )
614
+
615
+ logger.info(f"args: {args}")
616
+
617
+ if args.served_model_name is not None:
618
+ served_model = args.served_model_name
619
+ else:
620
+ served_model = args.model
621
+
622
+ engine_args = AsyncEngineArgs.from_cli_args(args)
623
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
624
+ engine_model_config = asyncio.run(engine.get_model_config())
625
+ max_model_len = engine_model_config.max_model_len
626
+
627
+ # A separate tokenizer to map token IDs to strings.
628
+ tokenizer = get_tokenizer(engine_args.tokenizer,
629
+ tokenizer_mode=engine_args.tokenizer_mode,
630
+ trust_remote_code=engine_args.trust_remote_code)
631
+
632
+ uvicorn.run(app,
633
+ host=args.host,
634
+ port=args.port,
635
+ log_level="info",
636
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: react_template
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: gpt-35-turbo
12
+ module_name: infiagent.llm
13
+ class_name: AzureOpenAIGPTClient
14
+ params:
15
+ temperature: 0.2
16
+ top_p: 0.95
17
+ repetition_penalty: 1.0
18
+ max_tokens: 4096
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox.yaml
configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: gpt_4_react
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: gpt-4-0613
12
+ module_name: infiagent.llm
13
+ class_name: AzureOpenAIGPTClient
14
+ params:
15
+ temperature: 0.2
16
+ top_p: 0.95
17
+ repetition_penalty: 1.0
18
+ max_tokens: 4096
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox.yaml
configs/agent_configs/react_agent_azureopenai_gpt_4_async_dcoker.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: gpt_4_react
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: gpt-4-0613
12
+ module_name: infiagent.llm
13
+ class_name: AzureOpenAIGPTClient
14
+ params:
15
+ temperature: 0.2
16
+ top_p: 0.95
17
+ repetition_penalty: 1.0
18
+ max_tokens: 4096
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox_docker.yaml
configs/agent_configs/react_agent_gpt4_async.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: react_template
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: gpt-4
12
+ module_name: infiagent.llm
13
+ class_name: OpenAIGPTClient
14
+ params:
15
+ temperature: 0.0
16
+ top_p: 0.9
17
+ repetition_penalty: 1.0
18
+ max_tokens: 1024
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox.yaml
configs/agent_configs/react_agent_llama_async.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: react_template
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: meta-llama/Llama-2-7b-hf
12
+ module_name: infiagent.llm
13
+ class_name: LlamaOpenAIClient
14
+ params:
15
+ temperature: 0.0
16
+ top_p: 0.9
17
+ repetition_penalty: 1.0
18
+ max_tokens: 1024
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox.yaml
configs/agent_configs/react_agent_opt_async.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Agent Template
2
+ name: react_template
3
+ version: 0.0.1
4
+ type: react
5
+ description: A react agent capable of code interpreter
6
+ module_name: infiagent.agent.react
7
+ class_name: AsyncReactAgent
8
+ target_tasks:
9
+ - code interpreter
10
+ llm:
11
+ model_name: facebook/opt-125m
12
+ module_name: infiagent.llm
13
+ class_name: OptOpenAIClient
14
+ params:
15
+ temperature: 0.0
16
+ top_p: 0.9
17
+ repetition_penalty: 1.0
18
+ max_tokens: 1024
19
+ prompt_template: !prompt ZeroShotReactPrompt
20
+ plugins:
21
+ - name: python_code_sandbox
22
+ type: tool
23
+ config: configs/tool_configs/async_python_code_sandbox.yaml
configs/tool_configs/async_python_code_sandbox.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ name: python_code_sandbox
2
+ version: 0.0.1
3
+ type: tool
4
+ description: this tool can help to run python script with python code as input
5
+ module_name: infiagent.tools
6
+ class_name: AsyncPythonSandBoxTool
7
+ session_id: none
configs/tool_configs/async_python_code_sandbox_docker.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ name: python_code_sandbox
2
+ version: 0.0.1
3
+ type: tool
4
+ description: this tool can help to run python script with python code as input
5
+ module_name: infiagent.tools
6
+ class_name: CodeTool
7
+ session_id: none
run.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
#!/bin/bash
set -ex
# Serve the public API app (src/activities/api.py) via uvicorn; PORT defaults to 3000.
poetry run python3 -m uvicorn src.activities.api:app --reload --host 0.0.0.0 --port ${PORT:-3000} --limit-max-requests 5000 --timeout-keep-alive 1200
run_demo.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/bin/bash
# set -ex

# Launch the Streamlit demo UI on port 6006; extra CLI args after `--` are
# forwarded to local_demo.py.
streamlit run ./activities/local_demo.py --server.port 6006 -- $@

run_local.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/bin/bash
set -ex
# Serve the local test app (src/activities/local_test.py) via uvicorn; PORT defaults to 3000.
poetry run python3 -m uvicorn src.activities.local_test:local_app --reload --host 0.0.0.0 --port ${PORT:-3000} --limit-max-requests 5000 --timeout-keep-alive 1200
4
+
setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages

# Packaging metadata for the `infiagent` distribution (src/ layout).
setup(
    name='infiagent',
    version='0.1.0',
    author='InfiAgent',
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    url='https://github.com/InfiAgent/ADA-Agent',
    license='LICENSE.txt',
    description='An awesome package for InfiAgent.',
    long_description=open('README.md').read(),
    # Ship the YAML agent/tool configs inside the installed package.
    package_data={
        'infiagent.configs.agent_configs': ['*.yaml'],
        'infiagent.configs.tool_configs': ['*.yaml'],
    },
    install_requires=[
        "streamlit",
        "pyyaml",
        "pytest",
        "openai==0.27.7",
        "fastapi",
        "uvicorn",
        "uvloop",
        "watchdog",
        "chardet",
        "werkzeug",
        "python-dotenv",
        "motor",
        "aiofiles",
        "sse_starlette",
        "loguru",
        "jupyter_client",
        "pandas",
        "scikit-learn",
        "scipy",
        "ipykernel"
    ],
    python_requires='>=3.9'
)
src/infiagent/__init__.py ADDED
File without changes
src/infiagent/agent/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .base_agent import BaseAgent
2
+ from .react import AsyncReactAgent
src/infiagent/agent/base_agent.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, Callable, Union, AsyncGenerator
4
+
5
+ from ..exceptions.exceptions import InputErrorException
6
+ from ..prompt import PromptTemplate
7
+ from ..schemas import AgentOutput, AgentType, AgentResponse
8
+
9
+ from ..llm.base_llm import BaseLLM
10
+
11
+ from ..tools import BaseTool
12
+ from ..utils import Config, get_logger
13
+
14
+ import os
15
+ from importlib import import_module
16
+
17
+ logger = get_logger()
18
+
19
+
20
+ LLM_CONF_OVERRIDE_KEY = ['psm', 'dc', 'temperature', 'top_p', 'top_k', 'max_tokens']
21
+
22
+
23
+ class BaseAgent(ABC):
24
+ """Base Agent class defining the essential attributes and methods for an ALM Agent.
25
+ """
26
+
27
+ def __init__(self, **kwargs):
28
+ """
29
+ Initializes an instance of the Agent class.
30
+ """
31
+ # Set default values
32
+ default_config = {
33
+ 'name': 'agent',
34
+ 'type': AgentType.react,
35
+ 'version': '',
36
+ 'description': '',
37
+ 'prompt_template': None,
38
+ 'auth': {}
39
+ }
40
+ # Update default values with provided config
41
+ default_config.update(kwargs)
42
+
43
+ # Access configuration data with a known default value
44
+ auth = default_config['auth']
45
+ self._set_auth_env(auth)
46
+
47
+ self._name: str = default_config['name']
48
+ self._type: AgentType = default_config['type']
49
+ self._version: str = default_config['version']
50
+ self._description: str = default_config['description']
51
+ self.__prompt_template: Union[PromptTemplate, None] = \
52
+ self._get_prompt_template(default_config['prompt_template'])
53
+ self.__llm: Union[BaseLLM, None] = None
54
+ self.__plugins_map: Dict = {}
55
+ self.__plugin_tool_function = {}
56
+ self.__plugin_tool_async_function = {}
57
+ self.__plugin_tool_description = None
58
+
59
+ @property
60
+ def name(self) -> str:
61
+ return self._name
62
+
63
+ @property
64
+ def type(self) -> AgentType:
65
+ return self._type
66
+
67
+ @property
68
+ def version(self) -> str:
69
+ return self._version
70
+
71
+ @property
72
+ def description(self) -> str:
73
+ return self._description
74
+
75
+ @property
76
+ def prompt_template(self) -> PromptTemplate:
77
+ return self.__prompt_template
78
+
79
+ @property
80
+ def llm(self) -> Union[BaseLLM, None]:
81
+ return self.__llm
82
+
83
+ @llm.setter
84
+ def llm(self, llm_client: BaseLLM):
85
+ if llm_client is None or not isinstance(llm_client, BaseLLM):
86
+ raise InputErrorException("Invalid llm client {}".format(type(llm_client)))
87
+ self.__llm = llm_client
88
+
89
+ @property
90
+ def plugins_map(self) -> Dict:
91
+ return self.__plugins_map.copy() # Return a copy to prevent external modification
92
+
93
+ def add_plugin(self, tool_name: str, tool):
94
+ if not tool_name or not tool:
95
+ raise InputErrorException("Adding invalid tool name: {}, type {}".format(tool_name, type(tool)))
96
+ self.__plugins_map[tool_name] = tool
97
+
98
+ def _set_auth_env(self, obj):
99
+ """This method sets environment variables for authentication.
100
+ """
101
+ for key in obj:
102
+ os.environ[key] = obj.get(key)
103
+
104
+ def _get_prompt_template(self, obj):
105
+ """This method returns a prompt template instance based on the provided configuration.
106
+ """
107
+ assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
108
+ if isinstance(obj, dict):
109
+ return {
110
+ key: self._parse_prompt_template(obj[key]) for key in obj
111
+ }
112
+ elif isinstance(obj, PromptTemplate):
113
+ ans = self._parse_prompt_template(obj)
114
+ return ans
115
+ else:
116
+ raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
117
+ .format(type(obj)))
118
+
119
+ def _parse_prompt_template(self, obj: Union[dict, PromptTemplate]):
120
+ """This method parses the prompt template configuration and returns a prompt template instance.
121
+ """
122
+ assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
123
+ if isinstance(obj, PromptTemplate):
124
+ return obj
125
+ return PromptTemplate(input_variables=obj['input_variables'],
126
+ template=obj['template'],
127
+ validate_template=bool(obj.get('validate_template', True)))
128
+
129
+ @classmethod
130
+ def _get_basic_instance_from_config(cls, config_data):
131
+ agent_module_name = config_data.get("module_name", None)
132
+ agent_class_name = config_data.get("class_name", None)
133
+ if not agent_module_name or not agent_class_name:
134
+ raise InputErrorException("Agent module_name and class_name required, please check your config")
135
+
136
+ module = import_module(agent_module_name)
137
+ clazz = getattr(module, agent_class_name)
138
+ agent_instance = clazz(**config_data)
139
+ return agent_instance
140
+
141
    @classmethod
    def from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Synchronously build an agent (with its LLM and plugins) from a config file.

        Keyword arguments whose names appear in ``LLM_CONF_OVERRIDE_KEY`` and are
        truthy override the corresponding entries of the config's ``llm.params``
        section before the LLM is constructed.
        """
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        # Apply caller overrides first; _init_llm below reads the mutated config.
        if 'llm' in config_data and 'params' in config_data['llm']:
            for param in LLM_CONF_OVERRIDE_KEY:
                if param in kwargs and kwargs[param]:
                    logger.info(f"Overwrite with new {param} {kwargs[param]}")
                    config_data['llm']['params'][param] = kwargs[param]

        assert isinstance(agent_instance, BaseAgent)
        agent_instance._init_llm(config_data.get("llm", {}))
        agent_instance._init_plugins(config_data.get('plugins', []))
        return agent_instance
157
+
158
    def _init_llm(self, obj):
        """
        Parse the LLM configuration and attach the constructed client to ``self.llm``.

        :param obj: A configuration dictionary, or a bare model-name string.
        :type obj: dict or str
        :raises KeyError: if a dict config lacks ``module_name``/``class_name``.
        """
        if isinstance(obj, str):
            name = obj
            model_params = dict()
        else:
            name = obj.get('model_name', None)
            model_params = obj.get('params', dict())

        # NOTE(review): when *obj* is a plain string, these subscript lookups
        # raise TypeError -- the string branch above appears unusable as-is;
        # confirm whether string configs are still supported.
        module_name = obj['module_name']
        class_name = obj['class_name']

        module = import_module(module_name)
        clazz = getattr(module, class_name)

        llm = clazz(model_name=name, params=model_params)
        self.llm = llm
183
+
184
    def _init_plugins(self, configs):
        """
        Instantiate each configured plugin (nested agent or tool) and register it
        in ``self.plugins_map``.

        :param configs: list of plugin config dicts; entries with ``type: agent``
            are built recursively as sub-agents, everything else as tools.
        """
        assert isinstance(configs, list)
        for plugin_config in configs:
            if plugin_config.get('type', "") == 'agent':
                # Agent as plugin: registered under the explicit config name.
                agent = BaseAgent.from_config_path_and_kwargs(plugin_config['config'])
                self.plugins_map[plugin_config['name']] = agent
            else:
                # Tools as plugin: registered under the tool's own name.
                params = plugin_config.get('params', dict())
                tool = BaseTool.from_config(config_input=plugin_config['config'], **params)
                self.plugins_map[tool.name] = tool
199
+
200
    @classmethod
    async def async_from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Async counterpart of :meth:`from_config_path_and_kwargs`.

        Builds the LLM and all plugins concurrently with ``asyncio.gather`` and
        registers each plugin on the new agent instance.
        """
        config_data = Config.load(config_path)
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        # override default config with user input
        if 'llm' in config_data and 'params' in config_data['llm']:
            for param in LLM_CONF_OVERRIDE_KEY:
                if param in kwargs and kwargs[param]:
                    logger.info(f"Overwrite with new {param} {kwargs[param]}")
                    config_data['llm']['params'][param] = kwargs[param]

        # Create tasks for llm and each individual plugin
        llm_config = config_data.get("llm", {})
        plugin_configs = config_data.get('plugins', [])

        # Create tasks for llm and each individual plugin
        llm_task = asyncio.create_task(cls._async_init_llm(llm_config))
        plugin_tasks = [asyncio.create_task(cls._async_init_plugin(plugin_config)) for
                        plugin_config in plugin_configs]

        # Gather results; order is preserved, llm first then plugins.
        llm, *plugins = await asyncio.gather(llm_task, *plugin_tasks)

        agent_instance.llm = llm
        for plugin in plugins:
            plugin_name, plugin_instance = plugin
            agent_instance.add_plugin(plugin_name, plugin_instance)
        return agent_instance
232
+
233
+ @classmethod
234
+ async def _async_init_llm(cls, llm_config):
235
+ llm_model_name = llm_config.get("module_name", None)
236
+ llm_class_name = llm_config.get("class_name", None)
237
+ if not llm_model_name or not llm_class_name:
238
+ raise InputErrorException("Agent LLM module_name and class_name required, please check your config")
239
+ module = import_module(llm_model_name)
240
+ clazz = getattr(module, llm_class_name)
241
+ assert issubclass(clazz, BaseLLM), f"{clazz} is not a subclass of BaseLLM"
242
+ llm_instance = await clazz.create(config_data=llm_config)
243
+ return llm_instance
244
+
245
    @classmethod
    async def _async_init_plugin(cls, plugin_config):
        """Asynchronously build one plugin and return a ``(name, instance)`` pair.

        ``type: agent`` configs produce a nested agent registered under the
        explicit config name; any other config is treated as a tool, and a
        missing ``name`` falls back to the tool's own name.
        """
        if plugin_config.get('type', "") == 'agent':
            # Agent as plugin
            agent = await BaseAgent.async_from_config_path_and_kwargs(plugin_config['config'])
            return plugin_config['name'], agent
        else:
            # Tool as plugin
            params = plugin_config.get('params', dict())
            name = plugin_config.get('name', None)
            config = plugin_config['config']

            tool = await BaseTool.async_from_config(config_input=config, **params)

            if name is None:
                name = tool.name
            logger.info("Init tool with name [{}], and description [{}]".format(name, tool.description))
            return name, tool
264
+
265
    @abstractmethod
    def run(self, *args, **kwargs) -> Union[AgentResponse, None]:
        """Abstract method to be overridden by child classes for running the agent.

        :return: The output of the agent, or None.
        :rtype: AgentResponse or None
        """
        pass
273
+
274
    async def async_run(self, *args, **kwargs) -> AsyncGenerator[AgentResponse, None]:
        """Run the agent as an async generator of responses.

        Default implementation yields the single result of the synchronous
        ``run``; async-native agents (e.g. AsyncReactAgent) override this with
        a true streaming generator.
        """
        yield self.run(*args, **kwargs)
280
+
281
    def _get_plugin_function_map(self, method_name: str) -> Dict[str, Callable]:
        """Return ``{plugin_name: bound method}`` for ``method_name``.

        The map is built once per method name ("run" / "async_run") and cached
        in private attributes; plugins that are neither BaseTool nor BaseAgent
        are skipped with a warning.
        """
        if method_name == "run" and self.__plugin_tool_function:
            return self.__plugin_tool_function
        elif method_name == "async_run" and self.__plugin_tool_async_function:
            return self.__plugin_tool_async_function

        function_map = {}

        for name, plugin_tool in self.plugins_map.items():
            if isinstance(plugin_tool, (BaseTool, BaseAgent)):
                function_map[name] = getattr(plugin_tool, method_name)
            else:
                logger.warning(f"No support for plugin name {name} of type {type(plugin_tool)}")

        # Cache for subsequent calls.
        if method_name == "run":
            self.__plugin_tool_function = function_map
        elif method_name == "async_run":
            self.__plugin_tool_async_function = function_map

        return function_map
301
+
302
    def get_plugin_tool_function(self) -> Dict[str, Callable]:
        """Map of plugin name to its synchronous ``run`` callable.

        :return: The function map (cached after first build).
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("run")
309
+
310
    def get_plugin_tool_async_function(self) -> Dict[str, Callable]:
        """Map of plugin name to its ``async_run`` coroutine function.

        :return: The function map (cached after first build).
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("async_run")
317
+
318
    def _get_plugin_description(self):
        """Build (and cache) the one-line-per-plugin description block used in prompts.

        Each line has the form ``name[input]: description``.

        :raises InputErrorException: if a plugin lacks a ``description`` attribute.
        """
        if self.__plugin_tool_description:
            return self.__plugin_tool_description

        descriptions = ""
        try:
            for plugin_name, plugin in self.plugins_map.items():
                descriptions += f"{plugin_name}[input]: {plugin.description}\n"
        except Exception as e:
            err_msg = "Failed to get plugin tool name and description. error: {}".format(str(e))
            raise InputErrorException(err_msg) from e

        self.__plugin_tool_description = descriptions
        return descriptions
332
+
333
    def clear(self):
        """
        Clear and reset the agent.

        No-op in the base class.
        """
        pass
src/infiagent/agent/react/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Public interface of the react agent subpackage.
from .async_react_agent import AsyncReactAgent
__all__ = [
    'AsyncReactAgent'
]
src/infiagent/agent/react/async_react_agent.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import time
from typing import Union, List, Dict

from werkzeug.datastructures import FileStorage

from .. import BaseAgent
from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException
from ...schemas import (
    AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse,
    BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile
)
from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool
from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \
    OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \
    TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN

# Fixed: these constants (and the logger) were previously defined twice in a
# row with identical values; each is now defined exactly once.

# Name under which the python sandbox tool is registered in the plugin map.
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
# Markers that signal the LLM has produced its final answer.
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
# Everything after the first stop word in an LLM reply is discarded.
STOP_WORD = ['Observation:']

logger = get_logger()
32
+
33
+
34
class AsyncReactAgent(BaseAgent):
    """ReAct-style agent: alternates LLM "thought" rounds with python-sandbox
    tool executions until a final-answer indicator is seen, streaming an
    AgentResponse for every step."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        # Alternating AgentAction / AgentObservation records; fed back into the
        # prompt scratchpad on every round.
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self):
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
        # Synchronous execution is not supported; use async_run.
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
        """Upload *file* into the session's sandbox via the sandbox plugin."""
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        # NOTE(review): the isinstance tuple lists AsyncPythonSandBoxTool twice;
        # one entry was probably meant to be the sync sandbox tool -- confirm.
        if not isinstance(sandbox_plugin, (AsyncPythonSandBoxTool, AsyncPythonSandBoxTool)):
            raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
        """Flatten the message history into one instruction and stream the chat."""
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
                    max_single_step_iterations=3):
        """Drive the ReAct loop: think, yield, act, yield -- up to max_iterations.

        :param is_cn: choose Chinese vs English user-facing prefixes.
        :param max_single_step_iterations: LLM retry budget per thought step.
        """
        current_iteration = 0

        for _ in range(max_iterations):
            current_iteration += 1
            llm_response = await self._single_round_thought(instruction,
                                                            max_llm_iteration=max_single_step_iterations,
                                                            is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Find final answer, stop iteration.")
                break

            # Not finished: record the action, execute it, record the observation.
            self.intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
                                                                                 max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.intermediate_steps.append(action_response)

            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)

        logger.info(f"Finished iteration in {current_iteration}.")

    # TODO update logic to not be sandbox specific, sandbox related logic should be handled in sandbox client
    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
        """Execute the parsed action in the sandbox and wrap it as an observation.

        :return: (AgentObservation, list of MediaFile produced by the step)
        :raises SandboxException: on any failure while running the tool.
        """
        try:
            # Hard-coded because the sandbox is currently the only tool.
            response.tool = 'python_code_sandbox'
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            # Error output can be huge; trim it before showing to the user/LLM.
            if "STDERR" in action_response.output_text:
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text

            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                  formatted_output=formatted_output,
                                                  raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files

        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            # TODO: We hard code here as we only have one tool
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        # NOTE(review): construct_scratchpad dereferences prompt_template before
        # the None check below -- if prompt_template is None this raises
        # AttributeError first; confirm the intended ordering.
        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(list(self.plugins_map.keys()))
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is none, please check init process")

        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
        """One LLM "thought" with retries; falls back to a canned failure answer.

        Retries both the completion call and output parsing up to
        max_llm_iteration times; on exhaustion returns an AgentFinish carrying
        the localized failure message so the loop terminates gracefully.
        """
        llm_iteration_count = 0

        llm_response = None
        while llm_iteration_count <= max_llm_iteration:
            llm_iteration_count += 1
            try:
                llm_response = await self._get_llm_response(instruction)
                action_response = self._parse_output(llm_response.content, is_cn)

                return action_response
            except Exception as e:
                logger.error("LLM iteration {} out of {} failed. Error: {}".
                             format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)

            if llm_iteration_count > max_llm_iteration:
                logger.error("LLM iteration {} exceed max retry {}. Aborting".
                             format(llm_iteration_count, max_llm_iteration))
                return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                                   raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
        """Send the composed prompt to the LLM and return its completion.

        :raises LLMException: when the completion comes back in error state.
        """
        prompt = self._compose_prompt(instruction)
        logger.info("Send prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from llm, raw response content: \n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM text into an AgentFinish or an AgentAction.

        Truncates at the first stop word, checks final-answer indicators, then
        matches Action / Action Input (fenced code) via regex.

        :raises LLMException: when the output matches no recognized format.
        """
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        # Check for Final Answer, if it is final, then just return
        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                # got final answer and remove the indicator
                parts = llm_output.split(indicator)
                # formatted_output = ''.join(parts[:-1]).strip()
                formatted_output = ''.join(parts).strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        # Updated regex pattern for capturing the expected input format.
        # NOTE(review): if the second alternative (''' fenced block) matches,
        # groups 1-3 are None and the .strip() calls below raise AttributeError
        # -- confirm whether that alternative is ever expected to fire.
        ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
        ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"

        action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL)

        # Find action, context, and action input, build action response
        if action_match:
            context = action_match.group(1).strip()
            action_tool_description = action_match.group(2).strip()
            action_input = action_match.group(3).strip()

            # Format code
            # TODO: currently we only have one plugin which is sandbox, update to support multiple tools
            format_code_block = self._format_code_block(action_input)

            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)

            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        # Not final answer and not action, raise exception
        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    def _format_code_block(self, tool_input):
        """Normalize *tool_input* into a ```python fenced code block."""
        stripped_tool_input = tool_input.strip()

        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Already a ```python block; just ensure a newline follows the tag.
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                                      '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Plain ``` fence: re-tag as python.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            # Bare code: wrap in a full fenced block.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'

        return formatted_code.encode("utf-8").decode("utf-8")

    def _process_sandbox_output(self, output: str):
        """Function to process the result containing STDERR.

        Outputs over 1000 chars are trimmed to roughly the first and last 500
        characters (respecting line boundaries) joined by "......".
        """
        if len(output) <= 1000:
            return output

        logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
                    format(output))
        rows = output.split("\n")
        # Get the first 500 characters, respecting line boundaries
        top_segment = []
        length = 0
        for sub_p in rows:
            if length + len(sub_p) > 500:
                break
            top_segment.append(sub_p)
            length += len(sub_p)

        # Get the last 500 characters, respecting line boundaries
        bottom_segment = []
        length = 0
        for sub_p in reversed(rows):
            if length + len(sub_p) > 500:
                break
            bottom_segment.insert(0, sub_p)
            length += len(sub_p)

        # Combine the segments with "......" in between
        timed_output = "\n".join(top_segment + ["......"] + bottom_segment)

        return timed_output

    def _get_output_files(self, tool_response) -> list[MediaFile]:
        """Collect MediaFile entries (generated files + image outputs) from a
        successful, non-partial sandbox run; otherwise return an empty list."""
        output_files = []

        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output

            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result

                # TODO confirm if we still need output and format
                if len(result_data.new_generated_files) > 0:
                    output_files.extend([MediaFile(tos_path=file.download_link) for file in
                                         result_data.new_generated_files])

                if len(result_data.code_output_result) > 0:
                    output_files.extend(
                        [MediaFile(tos_path=image.content) for image in result_data.code_output_result
                         if image.type == 'image'])

        return output_files

    def _replace_csv_path(self, input_string):
        """Replace any concrete pd.read_csv('...csv') path with a placeholder path."""
        # Search for the pattern and replace it
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output):
        """Pack one step's outputs into an AgentResponse."""
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)
299
+
src/infiagent/conversation_sessions/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .code_interpreter_session import CodeInterpreterSession
src/infiagent/conversation_sessions/code_interpreter_session.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from typing import Any, Dict, Union
5
+
6
+ from werkzeug.datastructures import FileStorage
7
+
8
+ from ..agent import BaseAgent
9
+ from ..agent.react import AsyncReactAgent
10
+ from ..schemas import AgentRequest, MediaFile, Message, RoleType
11
+ from ..utils import generate_random_string, get_logger, get_model_config_path
12
+
13
+ logger = get_logger()
14
+
15
+
16
class CodeInterpreterSession:
    """One chat session binding a ReAct agent to a per-session code sandbox.

    Keeps the running message history plus the files uploaded to (and produced
    by) the sandbox, and streams the agent's responses for every chat turn.
    """

    def __init__(
            self,
            session_id: Union[None, str] = None,
            model_name: Union[None, str] = "openai",
            config_path: Union[None, str] = None,
            agent: AsyncReactAgent = None,
            **kwargs):
        self.session_id = session_id
        self.config_path = config_path
        self.input_files = []   # MediaFile entries already synced into the sandbox
        self.output_files = []  # reserved for files produced by the agent
        self.messages = []      # full conversation history (Message objects)
        self.agent = agent
        self.llm_model_name = self.agent.llm.model_name

        # Bug fix: the original format string had only three placeholders for
        # four arguments, so the session id was silently dropped from the log.
        logger.info("Use model {} and llm {} in config {} for conversation {}"
                    .format(model_name, self.llm_model_name, self.config_path, self.session_id))

    @classmethod
    async def create(cls,
                     model_name: Union[None, str] = "openai",
                     config_path: Union[None, str] = None,
                     **kwargs: Dict[str, Any]):
        """Async factory: resolve the model config, build the agent and bind a
        freshly generated sandbox id to its sandbox plugin.

        :param model_name: key used to look up a default config when
            ``config_path`` is not given.
        """
        if config_path is None:
            config_path = get_model_config_path(model_name)
        logger.info(f"Use Config Path: {config_path}")

        sandbox_id = generate_random_string(12)

        # setup agent
        agent = await BaseAgent.async_from_config_path_and_kwargs(config_path, **kwargs)
        await agent.plugins_map["python_code_sandbox"].set_sandbox_id(sandbox_id)

        return cls(session_id=sandbox_id,
                   model_name=model_name,
                   config_path=config_path,
                   agent=agent)

    async def upload_to_sandbox(self, file: Union[str, FileStorage]):
        """Sync *file* into the sandbox and record it in the conversation state."""
        dst_path = await self.agent.sync_to_sandbox(file)
        message = f'User uploaded the following files: {dst_path}\n'
        # Consistency fix: use the module logger instead of the root logger.
        logger.info(f"The file path {file} has been synced to sandbox with file path {dst_path}")
        self.messages.append(Message(RoleType.System, message))
        self.input_files.append(MediaFile(file_name=os.path.basename(dst_path), sandbox_path=dst_path))

    async def chat(self, user_messages, input_files=None):
        """Run one conversation turn and stream the agent's responses.

        NOTE(review): ``input_files`` is currently unused -- files must be
        registered via ``upload_to_sandbox`` beforehand; confirm intent.
        """
        start_time = time.time()

        self.messages.extend(user_messages)
        agent_request = AgentRequest(
            messages=self.messages,
            input_files=self.input_files,
            sandbox_id=self.session_id
        )
        logger.info(f"Agent request: {agent_request.__dict__}")

        async for agent_response in self.agent.async_run(agent_request):
            logger.info(f"Agent response:\n{agent_response.output_text}")
            self.messages.append(Message(RoleType.System, agent_response.output_text))
            yield agent_response

        exec_time = time.time()
        logger.info(
            f'Agent Execution Latency: {exec_time - start_time}'
        )

    def __enter__(self):
        # Fix: return self so `with CodeInterpreterSession(...) as session:`
        # binds the session (previously bound None).
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        pass
+ pass
src/infiagent/exceptions/__init__.py ADDED
File without changes
src/infiagent/exceptions/exceptions.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class DependencyException(Exception):
    """Base class for failures of external dependencies (database, sandbox, LLM)."""
3
+
4
+
5
class InputErrorException(Exception):
    """Raised when caller-supplied input or configuration is invalid."""
7
+
8
+
9
class InternalErrorException(Exception):
    """Raised for unexpected internal errors (bugs, broken invariants)."""
11
+
12
+
13
class DatabaseException(DependencyException):
    """Raised when a database operation fails.

    The explicit ``__init__`` was removed: it only forwarded ``(message, *args)``
    to ``super().__init__``, which is exactly what ``Exception`` does by default.
    """
16
+
17
+
18
class SandboxException(DependencyException):
    """Raised when the code sandbox fails or is unavailable.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
21
+
22
+
23
class LLMException(DependencyException):
    """Raised when an LLM call fails or returns an unusable response.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
26
+
27
+
28
class ModelMaxIterationsException(DependencyException):
    """Raised when the model exceeds its maximum iteration budget.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
31
+
32
+
33
class InvalidConfigException(InputErrorException):
    """Raised when a configuration file or dict is malformed.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
36
+
37
+
38
class SandBoxFileUploadException(SandboxException):
    """Raised when uploading a file to the sandbox fails.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
41
+
42
+
43
class PluginException(DependencyException):
    """Raised when a plugin (tool or sub-agent) fails.

    Redundant pass-through ``__init__`` removed; ``Exception`` already accepts
    ``(message, *args)``.
    """
46
+
src/infiagent/llm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .client.openai import *
2
+ from .client.azure_openai import *
3
+ from .client.opt import *
4
+ from .client.llama import *
5
+ from .base_llm import *
src/infiagent/llm/base_llm.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ from ..exceptions.exceptions import InputErrorException
4
+ from ..schemas import BaseCompletion
5
+
6
+
7
class BaseLLM(ABC):
    """Abstract base class for LLM clients.

    Stores the model name and a free-form parameter dict; concrete clients
    implement the (async) completion methods and the ``create`` factory.
    """

    def __init__(self, model_name: str, params: dict, **kwargs):
        # NOTE(review): __init__ accepts model_name=None without validation,
        # while the setter below rejects None -- confirm whether that
        # asymmetry is intended.
        self.__model_name = model_name
        self.__params = params

    @classmethod
    async def create(cls, config_data: dict):
        """Async factory; base implementation returns None -- subclasses override."""
        pass

    @property
    def model_name(self) -> str:
        return self.__model_name

    @model_name.setter
    def model_name(self, model_name):
        if model_name is None:
            raise InputErrorException("Invalid model_name {}".format(model_name))
        self.__model_name = model_name

    @property
    def params(self) -> dict:
        # Model parameters as given at construction time (read-only property).
        return self.__params

    def completion(self, prompt) -> BaseCompletion:
        """Synchronous completion; no-op (returns None) in the base class."""
        pass

    async def async_completion(self, prompt) -> BaseCompletion:
        """Asynchronous completion; no-op (returns None) in the base class."""
        pass
36
+
src/infiagent/llm/client/__init__.py ADDED
File without changes
src/infiagent/llm/client/azure_openai.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from abc import ABC
5
+ from typing import Callable, List
6
+
7
+ import openai
8
+ from tenacity import ( # for exponential backoff
9
+ before_sleep_log,
10
+ retry,
11
+ stop_after_attempt,
12
+ wait_random_exponential,
13
+ )
14
+
15
+ from ..base_llm import BaseLLM
16
+ from ...schemas import *
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ MAX_PROMPT_LENGTH = 7000
21
+
22
+
23
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(100), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
def chatcompletion_with_backoff(**kwargs):
    """Call the (Azure) OpenAI chat-completion API with retry/backoff.

    Retries up to 100 attempts with random exponential backoff (1-10s),
    logging each retry at WARNING level, and re-raises the final error.
    """
    return openai.ChatCompletion.create(**kwargs)
27
+
28
+
29
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(100), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
async def async_chatcompletion_with_backoff(**kwargs):
    """Async variant of :func:`chatcompletion_with_backoff`.

    Simplified: the original wrapped the API call in a nested coroutine that
    was awaited immediately, which added nothing. Retry/backoff behavior is
    unchanged (tenacity's ``retry`` supports async functions natively).
    """
    return await openai.ChatCompletion.acreate(**kwargs)
36
+
37
+
38
+ class AzureOpenAIGPTClient(BaseLLM, ABC):
39
+ """
40
+ Wrapper class for OpenAI GPT API collections.
41
+
42
+ :param model_name: The name of the model to use.
43
+ :type model_name: str
44
+ :param params: The parameters for the model.
45
+ :type params: AzureOpenAIParamModel
46
+ """
47
+
48
+ model_name: str
49
+ params: AzureOpenAIParamModel = AzureOpenAIParamModel()
50
+
51
+ def __init__(self, **data):
52
+ super().__init__(**data)
53
+ openai.api_key = os.environ.get("OPENAI_API_KEY", "")
54
+ openai.api_type = "azure"
55
+ openai.api_base = "https://search.bytedance.net/gpt/openapi/online/v2/crawl"
56
+ openai.api_version = "2023-06-01-preview"
57
+
58
+ @classmethod
59
+ async def create(cls, config_data):
60
+ return AzureOpenAIGPTClient(**config_data)
61
+
62
+ def get_model_name(self) -> str:
63
+ return self.model_name
64
+
65
+ def get_model_param(self) -> AzureOpenAIParamModel:
66
+ return self.params
67
+
68
+ def completion(self, prompt: str, **kwargs) -> BaseCompletion:
69
+ """
70
+ Completion method for OpenAI GPT API.
71
+
72
+ :param prompt: The prompt to use for completion.
73
+ :type prompt: str
74
+ :param kwargs: Additional keyword arguments.
75
+ :type kwargs: dict
76
+ :return: BaseCompletion object.
77
+ :rtype: BaseCompletion
78
+
79
+ """
80
+
81
+ response = chatcompletion_with_backoff(
82
+ engine=self.get_model_name(), # GPT-4
83
+ messages=[
84
+ {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
85
+ ],
86
+ timeout=1000,
87
+ temperature=self.params.temperature,
88
+ max_tokens=self.params.max_tokens,
89
+ top_p=self.params.top_p,
90
+ frequency_penalty=self.params.frequency_penalty,
91
+ presence_penalty=self.params.presence_penalty,
92
+ **kwargs
93
+ )
94
+
95
+ return BaseCompletion(state="success",
96
+ content=response.choices[0].message["content"],
97
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
98
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
99
+
100
+ async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
101
+ """
102
+ Completion method for OpenAI GPT API.
103
+
104
+ :param prompt: The prompt to use for completion.
105
+ :type prompt: str
106
+ :param kwargs: Additional keyword arguments.
107
+ :type kwargs: dict
108
+ :return: BaseCompletion object.
109
+ :rtype: BaseCompletion
110
+
111
+ """
112
+ response = await async_chatcompletion_with_backoff(
113
+ engine=self.get_model_name(),
114
+ messages=[
115
+ {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
116
+ ],
117
+ timeout=1000,
118
+ temperature=self.params.temperature,
119
+ max_tokens=self.params.max_tokens,
120
+ top_p=self.params.top_p,
121
+ frequency_penalty=self.params.frequency_penalty,
122
+ presence_penalty=self.params.presence_penalty,
123
+ **kwargs
124
+ )
125
+
126
+ return BaseCompletion(state="success",
127
+ content=response.choices[0].message["content"],
128
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
129
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
130
+
131
+ def chat_completion(self, message: List[dict]) -> ChatCompletion:
132
+ """
133
+ Chat completion method for OpenAI GPT API.
134
+
135
+ :param message: The message to use for completion.
136
+ :type message: List[dict]
137
+ :return: ChatCompletion object.
138
+ :rtype: ChatCompletion
139
+ """
140
+ try:
141
+ response = openai.ChatCompletion.create(
142
+ engine=self.get_model_name(), # GPT-4
143
+ messages=message,
144
+ timeout=1000,
145
+ )
146
+
147
+ return ChatCompletion(
148
+ state="success",
149
+ role=response.choices[0].message["role"],
150
+ content=response.choices[0].message["content"],
151
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
152
+ completion_token=response.get("usage", {}).get("completion_tokens", 0),
153
+ )
154
+ except Exception as exception:
155
+ print("Exception:", exception)
156
+ return ChatCompletion(state="error", content=exception)
157
+
158
+ def stream_chat_completion(self, message: List[dict], **kwargs):
159
+ """
160
+ Stream output chat completion for OpenAI GPT API.
161
+
162
+ :param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
163
+ :type message: List[dict]
164
+ :param kwargs: Additional keyword arguments.
165
+ :type kwargs: dict
166
+ :return: ChatCompletion object.
167
+ :rtype: ChatCompletion
168
+ """
169
+ try:
170
+ response = openai.ChatCompletion.create(
171
+ engine=self.get_model_name(), # GPT-4
172
+ messages=message,
173
+ timeout=1000,
174
+ **kwargs,
175
+ )
176
+
177
+ role = next(response).choices[0].delta["role"]
178
+ messages = []
179
+ ## TODO: Calculate prompt_token and for stream mode
180
+ for resp in response:
181
+ messages.append(resp.choices[0].delta.get("content", ""))
182
+ yield ChatCompletion(
183
+ state="success",
184
+ role=role,
185
+ content=messages[-1],
186
+ prompt_token=0,
187
+ completion_token=0,
188
+ )
189
+ except Exception as exception:
190
+ print("Exception:", exception)
191
+ return ChatCompletion(state="error", content=exception)
192
+
193
+ def function_chat_completion(
194
+ self,
195
+ message: List[dict],
196
+ function_map: Dict[str, Callable],
197
+ function_schema: List[Dict],
198
+ ) -> ChatCompletionWithHistory:
199
+ """
200
+ Chat completion method for OpenAI GPT API.
201
+
202
+ :param message: The message to use for completion.
203
+ :type message: List[dict]
204
+ :param function_map: The function map to use for completion.
205
+ :type function_map: Dict[str, Callable]
206
+ :param function_schema: The function schema to use for completion.
207
+ :type function_schema: List[Dict]
208
+ :return: ChatCompletionWithHistory object.
209
+ :rtype: ChatCompletionWithHistory
210
+ """
211
+ assert len(function_schema) == len(function_map)
212
+ try:
213
+ response = openai.ChatCompletion.create(
214
+ engine=self.get_model_name(), # GPT-4
215
+ messages=message,
216
+ functions=function_schema,
217
+ timeout=1000,
218
+ )
219
+ # response = openai.ChatCompletion.create(
220
+ # n=self.params.n,
221
+ # model=self.model_name,
222
+ # messages=message,
223
+ # functions=function_schema,
224
+ # temperature=self.params.temperature,
225
+ # max_tokens=self.params.max_tokens,
226
+ # top_p=self.params.top_p,
227
+ # frequency_penalty=self.params.frequency_penalty,
228
+ # presence_penalty=self.params.presence_penalty,
229
+ # )
230
+ response_message = response.choices[0]["message"]
231
+
232
+ if response_message.get("function_call"):
233
+ function_name = response_message["function_call"]["name"]
234
+ fuction_to_call = function_map[function_name]
235
+ function_args = json.loads(
236
+ response_message["function_call"]["arguments"]
237
+ )
238
+ function_response = fuction_to_call(**function_args)
239
+
240
+ # Postprocess function response
241
+ if isinstance(function_response, str):
242
+ plugin_cost = 0
243
+ plugin_token = 0
244
+ elif isinstance(function_response, AgentOutput):
245
+ plugin_cost = function_response.cost
246
+ plugin_token = function_response.token_usage
247
+ function_response = function_response.output
248
+ else:
249
+ raise Exception(
250
+ "Invalid tool response type. Must be on of [AgentOutput, str]"
251
+ )
252
+
253
+ message.append(dict(response_message))
254
+ message.append(
255
+ {
256
+ "role": "function",
257
+ "name": function_name,
258
+ "content": function_response,
259
+ }
260
+ )
261
+ second_response = openai.ChatCompletion.create(
262
+ model=self.get_model_name(),
263
+ messages=message,
264
+ )
265
+ message.append(dict(second_response.choices[0].message))
266
+ return ChatCompletionWithHistory(
267
+ state="success",
268
+ role=second_response.choices[0].message["role"],
269
+ content=second_response.choices[0].message["content"],
270
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
271
+ + second_response.get("usage", {}).get("prompt_tokens", 0),
272
+ completion_token=response.get("usage", {}).get(
273
+ "completion_tokens", 0
274
+ )
275
+ + second_response.get("usage", {}).get("completion_tokens", 0),
276
+ message_scratchpad=message,
277
+ plugin_cost=plugin_cost,
278
+ plugin_token=plugin_token,
279
+ )
280
+ else:
281
+ message.append(dict(response_message))
282
+ return ChatCompletionWithHistory(
283
+ state="success",
284
+ role=response.choices[0].message["role"],
285
+ content=response.choices[0].message["content"],
286
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
287
+ completion_token=response.get("usage", {}).get(
288
+ "completion_tokens", 0
289
+ ),
290
+ message_scratchpad=message,
291
+ )
292
+
293
+ except Exception as exception:
294
+ print("Exception:", exception)
295
+ return ChatCompletionWithHistory(state="error", content=str(exception))
296
+
297
+ def function_chat_stream_completion(
298
+ self,
299
+ message: List[dict],
300
+ function_map: Dict[str, Callable],
301
+ function_schema: List[Dict],
302
+ ) -> ChatCompletionWithHistory:
303
+ assert len(function_schema) == len(function_map)
304
+ try:
305
+ response = openai.ChatCompletion.create(
306
+ n=self.params.n,
307
+ model=self.get_model_name(),
308
+ messages=message,
309
+ functions=function_schema,
310
+ temperature=self.params.temperature,
311
+ max_tokens=self.params.max_tokens,
312
+ top_p=self.params.top_p,
313
+ frequency_penalty=self.params.frequency_penalty,
314
+ presence_penalty=self.params.presence_penalty,
315
+ stream=True,
316
+ )
317
+ tmp = next(response)
318
+ role = tmp.choices[0].delta["role"]
319
+ _type = (
320
+ "function_call"
321
+ if tmp.choices[0].delta["content"] is None
322
+ else "content"
323
+ )
324
+ if _type == "function_call":
325
+ name = tmp.choices[0].delta["function_call"]["name"]
326
+ yield _type, ChatCompletionWithHistory(
327
+ state="success",
328
+ role=role,
329
+ content="{" + f'"name":"{name}", "arguments":',
330
+ message_scratchpad=message,
331
+ )
332
+ for resp in response:
333
+ # print(resp)
334
+ content = resp.choices[0].delta.get(_type, "")
335
+ if isinstance(content, dict):
336
+ content = content["arguments"]
337
+ yield _type, ChatCompletionWithHistory(
338
+ state="success",
339
+ role=role,
340
+ content=content,
341
+ message_scratchpad=message,
342
+ )
343
+
344
+ except Exception as e:
345
+ logger.error(f"Failed to get response {str(e)}", exc_info=True)
346
+ raise e
src/infiagent/llm/client/llama.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from abc import ABC
5
+ from typing import Callable, List
6
+
7
+ import openai
8
+ from tenacity import ( # for exponential backoff
9
+ before_sleep_log,
10
+ retry,
11
+ stop_after_attempt,
12
+ wait_random_exponential,
13
+ )
14
+
15
+ from ..base_llm import BaseLLM
16
+ from ...schemas import *
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ MAX_PROMPT_LENGTH = 4096
21
+
22
+
23
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(5), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
def chatcompletion_with_backoff(**kwargs):
    """Call openai.ChatCompletion.create with up to 5 exponential-backoff retries."""
    return openai.ChatCompletion.create(**kwargs)
27
+
28
+
29
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(5), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
async def async_chatcompletion_with_backoff(**kwargs):
    """Async ChatCompletion call with up to 5 exponential-backoff retries.

    SIMPLIFICATION: the original wrapped the call in a needless inner
    coroutine (``_internal_coroutine``) that was immediately awaited;
    awaiting the API call directly is equivalent.
    """
    return await openai.ChatCompletion.acreate(**kwargs)
36
+
37
+
38
class LlamaOpenAIClient(BaseLLM, ABC):
    """
    OpenAI-API-compatible client for a locally served Llama model.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The parameters for the model.
    :type params: LlamaParamModel
    """

    model_name: str
    params: LlamaParamModel = LlamaParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # Local OpenAI-compatible server; no API key required.
        # NOTE(review): the endpoint is hard-coded — consider making it
        # configurable via config/env.
        openai.api_key = ""
        openai.api_base = "http://0.0.0.0:9729/v1"

    @classmethod
    async def create(cls, config_data):
        """Async factory.

        BUG FIX: uses ``cls`` rather than the hard-coded class name so
        subclasses constructed through ``create`` get the subclass type.
        """
        return cls(**config_data)

    def get_model_name(self) -> str:
        """Return the model name passed to the API."""
        return self.model_name

    def get_model_param(self) -> LlamaParamModel:
        """Return the generation-parameter model for this client."""
        return self.params
65
+
66
+ def completion(self, prompt: str, **kwargs) -> BaseCompletion:
67
+ """
68
+ Completion method for OpenAI GPT API.
69
+
70
+ :param prompt: The prompt to use for completion.
71
+ :type prompt: str
72
+ :param kwargs: Additional keyword arguments.
73
+ :type kwargs: dict
74
+ :return: BaseCompletion object.
75
+ :rtype: BaseCompletion
76
+ """
77
+ response = chatcompletion_with_backoff(
78
+ model=self.model_name,
79
+ messages=[
80
+ {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
81
+ ],
82
+ timeout=1000,
83
+ # temperature=self.params.temperature,
84
+ # max_tokens=self.params.max_tokens,
85
+ # top_p=self.params.top_p,
86
+ # frequency_penalty=self.params.frequency_penalty,
87
+ # presence_penalty=self.params.presence_penalty,
88
+ # stop=["<|im_end|>", "<|endoftext|>"],
89
+ **kwargs
90
+ )
91
+
92
+ return BaseCompletion(state="success",
93
+ content=response.choices[0].message["content"],
94
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
95
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
96
+
97
+ async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
98
+ """
99
+ Completion method for OpenAI GPT API.
100
+
101
+ :param prompt: The prompt to use for completion.
102
+ :type prompt: str
103
+ :param kwargs: Additional keyword arguments.
104
+ :type kwargs: dict
105
+ :return: BaseCompletion object.
106
+ :rtype: BaseCompletion
107
+
108
+ """
109
+ response = await async_chatcompletion_with_backoff(
110
+ model=self.model_name,
111
+ messages=[
112
+ {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}
113
+ ],
114
+ timeout=1000,
115
+ #temperature=0.2,
116
+ #max_tokens=4096,
117
+ #top_p=0.9,
118
+ #frequency_penalty=self.params.frequency_penalty,
119
+ #presence_penalty=self.params.presence_penalty,
120
+ # stop=["<|im_end|>", "<|endoftext|>"],
121
+ **kwargs
122
+ )
123
+
124
+ return BaseCompletion(state="success",
125
+ content=response.choices[0].message["content"],
126
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
127
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
128
+
129
+ def chat_completion(self, message: List[dict]) -> ChatCompletion:
130
+ """
131
+ Chat completion method for OpenAI GPT API.
132
+
133
+ :param message: The message to use for completion.
134
+ :type message: List[dict]
135
+ :return: ChatCompletion object.
136
+ :rtype: ChatCompletion
137
+ """
138
+ try:
139
+ response = openai.ChatCompletion.create(
140
+ n=self.params.n,
141
+ model=self.model_name,
142
+ timeout=1000,
143
+ messages=message,
144
+ temperature=self.params.temperature,
145
+ max_tokens=self.params.max_tokens,
146
+ top_p=self.params.top_p,
147
+ frequency_penalty=self.params.frequency_penalty,
148
+ presence_penalty=self.params.presence_penalty,
149
+ )
150
+ return ChatCompletion(
151
+ state="success",
152
+ role=response.choices[0].message["role"],
153
+ content=response.choices[0].message["content"],
154
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
155
+ completion_token=response.get("usage", {}).get("completion_tokens", 0),
156
+ )
157
+ except Exception as exception:
158
+ print("Exception:", exception)
159
+ return ChatCompletion(state="error", content=exception)
160
+
161
+ def stream_chat_completion(self, message: List[dict], **kwargs):
162
+ """
163
+ Stream output chat completion for OpenAI GPT API.
164
+
165
+ :param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
166
+ :type message: List[dict]
167
+ :param kwargs: Additional keyword arguments.
168
+ :type kwargs: dict
169
+ :return: ChatCompletion object.
170
+ :rtype: ChatCompletion
171
+ """
172
+ try:
173
+ # response = openai.ChatCompletion.create(
174
+ # engine=self.get_model_name(), # GPT-4
175
+ # messages=message,
176
+ # timeout=1000,
177
+ # **kwargs,
178
+ # )
179
+ response = openai.ChatCompletion.create(
180
+ n=self.params.n,
181
+ model=self.model_name,
182
+ messages=message,
183
+ temperature=self.params.temperature,
184
+ max_tokens=self.params.max_tokens,
185
+ top_p=self.params.top_p,
186
+ frequency_penalty=self.params.frequency_penalty,
187
+ presence_penalty=self.params.presence_penalty,
188
+ stream=True,
189
+ **kwargs
190
+ )
191
+ role = next(response).choices[0].delta["role"]
192
+ messages = []
193
+ ## TODO: Calculate prompt_token and for stream mode
194
+ for resp in response:
195
+ messages.append(resp.choices[0].delta.get("content", ""))
196
+ yield ChatCompletion(
197
+ state="success",
198
+ role=role,
199
+ content=messages[-1],
200
+ prompt_token=0,
201
+ completion_token=0,
202
+ )
203
+ except Exception as exception:
204
+ print("Exception:", exception)
205
+ return ChatCompletion(state="error", content=exception)
206
+
207
+ def function_chat_completion(
208
+ self,
209
+ message: List[dict],
210
+ function_map: Dict[str, Callable],
211
+ function_schema: List[Dict],
212
+ ) -> ChatCompletionWithHistory:
213
+ """
214
+ Chat completion method for OpenAI GPT API.
215
+
216
+ :param message: The message to use for completion.
217
+ :type message: List[dict]
218
+ :param function_map: The function map to use for completion.
219
+ :type function_map: Dict[str, Callable]
220
+ :param function_schema: The function schema to use for completion.
221
+ :type function_schema: List[Dict]
222
+ :return: ChatCompletionWithHistory object.
223
+ :rtype: ChatCompletionWithHistory
224
+ """
225
+ assert len(function_schema) == len(function_map)
226
+ try:
227
+ # response = openai.ChatCompletion.create(
228
+ # engine=self.get_model_name(), # GPT-4
229
+ # messages=message,
230
+ # functions=function_schema,
231
+ # timeout=1000,
232
+ # )
233
+ response = openai.ChatCompletion.create(
234
+ n=self.params.n,
235
+ model=self.model_name,
236
+ messages=message,
237
+ functions=function_schema,
238
+ temperature=self.params.temperature,
239
+ max_tokens=self.params.max_tokens,
240
+ top_p=self.params.top_p,
241
+ frequency_penalty=self.params.frequency_penalty,
242
+ presence_penalty=self.params.presence_penalty,
243
+ )
244
+ response_message = response.choices[0]["message"]
245
+
246
+ if response_message.get("function_call"):
247
+ function_name = response_message["function_call"]["name"]
248
+ fuction_to_call = function_map[function_name]
249
+ function_args = json.loads(
250
+ response_message["function_call"]["arguments"]
251
+ )
252
+ function_response = fuction_to_call(**function_args)
253
+
254
+ # Postprocess function response
255
+ if isinstance(function_response, str):
256
+ plugin_cost = 0
257
+ plugin_token = 0
258
+ elif isinstance(function_response, AgentOutput):
259
+ plugin_cost = function_response.cost
260
+ plugin_token = function_response.token_usage
261
+ function_response = function_response.output
262
+ else:
263
+ raise Exception(
264
+ "Invalid tool response type. Must be on of [AgentOutput, str]"
265
+ )
266
+
267
+ message.append(dict(response_message))
268
+ message.append(
269
+ {
270
+ "role": "function",
271
+ "name": function_name,
272
+ "content": function_response,
273
+ }
274
+ )
275
+ second_response = openai.ChatCompletion.create(
276
+ model=self.get_model_name(),
277
+ messages=message,
278
+ )
279
+ message.append(dict(second_response.choices[0].message))
280
+ return ChatCompletionWithHistory(
281
+ state="success",
282
+ role=second_response.choices[0].message["role"],
283
+ content=second_response.choices[0].message["content"],
284
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
285
+ + second_response.get("usage", {}).get("prompt_tokens", 0),
286
+ completion_token=response.get("usage", {}).get(
287
+ "completion_tokens", 0
288
+ )
289
+ + second_response.get("usage", {}).get("completion_tokens", 0),
290
+ message_scratchpad=message,
291
+ plugin_cost=plugin_cost,
292
+ plugin_token=plugin_token,
293
+ )
294
+ else:
295
+ message.append(dict(response_message))
296
+ return ChatCompletionWithHistory(
297
+ state="success",
298
+ role=response.choices[0].message["role"],
299
+ content=response.choices[0].message["content"],
300
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
301
+ completion_token=response.get("usage", {}).get(
302
+ "completion_tokens", 0
303
+ ),
304
+ message_scratchpad=message,
305
+ )
306
+
307
+ except Exception as exception:
308
+ print("Exception:", exception)
309
+ return ChatCompletionWithHistory(state="error", content=str(exception))
310
+
311
+ def function_chat_stream_completion(
312
+ self,
313
+ message: List[dict],
314
+ function_map: Dict[str, Callable],
315
+ function_schema: List[Dict],
316
+ ) -> ChatCompletionWithHistory:
317
+ assert len(function_schema) == len(function_map)
318
+ try:
319
+ response = openai.ChatCompletion.create(
320
+ n=self.params.n,
321
+ model=self.get_model_name(),
322
+ messages=message,
323
+ functions=function_schema,
324
+ temperature=self.params.temperature,
325
+ max_tokens=self.params.max_tokens,
326
+ top_p=self.params.top_p,
327
+ frequency_penalty=self.params.frequency_penalty,
328
+ presence_penalty=self.params.presence_penalty,
329
+ stream=True,
330
+ )
331
+ tmp = next(response)
332
+ role = tmp.choices[0].delta["role"]
333
+ _type = (
334
+ "function_call"
335
+ if tmp.choices[0].delta["content"] is None
336
+ else "content"
337
+ )
338
+ if _type == "function_call":
339
+ name = tmp.choices[0].delta["function_call"]["name"]
340
+ yield _type, ChatCompletionWithHistory(
341
+ state="success",
342
+ role=role,
343
+ content="{" + f'"name":"{name}", "arguments":',
344
+ message_scratchpad=message,
345
+ )
346
+ for resp in response:
347
+ # print(resp)
348
+ content = resp.choices[0].delta.get(_type, "")
349
+ if isinstance(content, dict):
350
+ content = content["arguments"]
351
+ yield _type, ChatCompletionWithHistory(
352
+ state="success",
353
+ role=role,
354
+ content=content,
355
+ message_scratchpad=message,
356
+ )
357
+
358
+ # result = ''.join(messages)
359
+ # if _type == "function_call":
360
+ # result = json.loads(result)
361
+ # function_name = result["name"]
362
+ # fuction_to_call = function_map[function_name]
363
+ # function_args = result["arguments"]
364
+ # function_response = fuction_to_call(**function_args)
365
+ #
366
+ # # Postprocess function response
367
+ # if isinstance(function_response, AgentOutput):
368
+ # function_response = function_response.output
369
+ # message.append({"role": "function",
370
+ # "name": function_name,
371
+ # "content": function_response})
372
+ # second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema)
373
+ # message.append(dict(second_response.choices[0].message))
374
+
375
+ except Exception as e:
376
+ logger.error(f"Failed to get response {str(e)}", exc_info=True)
377
+ raise e
src/infiagent/llm/client/openai.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from abc import ABC
4
+ from typing import Callable, List
5
+
6
+ import openai
7
+
8
+ from ..base_llm import BaseLLM
9
+ from ...schemas import *
10
+
11
+
12
class OpenAIGPTClient(BaseLLM, ABC):
    """
    Wrapper class for OpenAI GPT API collections.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The parameters for the model.
    :type params: OpenAIParamModel
    """
    model_name: str
    params: OpenAIParamModel = OpenAIParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # API key is read from the environment; empty string if unset.
        openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    @classmethod
    async def create(cls, config_data):
        # Async factory used by the framework's config loader.
        return OpenAIGPTClient(**config_data)

    def get_model_name(self) -> str:
        # Name of the model passed to the API.
        return self.model_name

    def get_model_param(self) -> OpenAIParamModel:
        # Generation parameters for this client.
        return self.params
37
+
38
+ def completion(self, prompt: str, **kwargs) -> BaseCompletion:
39
+ """
40
+ Completion method for OpenAI GPT API.
41
+
42
+ :param prompt: The prompt to use for completion.
43
+ :type prompt: str
44
+ :param kwargs: Additional keyword arguments.
45
+ :type kwargs: dict
46
+ :return: BaseCompletion object.
47
+ :rtype: BaseCompletion
48
+
49
+ """
50
+ try:
51
+ #TODO any full parameters support
52
+ response = openai.ChatCompletion.create(
53
+ # n=self.params['n'],
54
+ engine=self.model_name,
55
+ messages=[{"role": "user", "content": prompt}],
56
+ temperature=self.params['temperature'],
57
+ max_tokens=self.params['max_tokens'],
58
+ top_p=self.params['top_p'],
59
+ # frequency_penalty=self.params.frequency_penalty,
60
+ # presence_penalty=self.params.presence_penalty,
61
+ **kwargs
62
+ )
63
+ return BaseCompletion(state="success",
64
+ content=response.choices[0].message["content"],
65
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
66
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
67
+ except Exception as exception:
68
+ print("Exception:", exception)
69
+ return BaseCompletion(state="error", content=exception)
70
+
71
+ async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
72
+ """
73
+ Async Completion method for OpenAI GPT API.
74
+
75
+ :param prompt: The prompt to use for completion.
76
+ :type prompt: str
77
+ :param kwargs: Additional keyword arguments.
78
+ :type kwargs: dict
79
+ :return: BaseCompletion object.
80
+ :rtype: BaseCompletion
81
+
82
+ """
83
+ try:
84
+ response = await openai.ChatCompletion.acreate(
85
+ model=self.model_name,
86
+ messages=[{"role": "user", "content": prompt}],
87
+ temperature=self.params['temperature'],
88
+ max_tokens=self.params['max_tokens'],
89
+ top_p=self.params['top_p'],
90
+ # frequency_penalty=self.params.frequency_penalty,
91
+ # presence_penalty=self.params.presence_penalty,
92
+ **kwargs
93
+ )
94
+ return BaseCompletion(state="success",
95
+ content=response.choices[0].message["content"],
96
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
97
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
98
+ except Exception as exception:
99
+ print("Exception:", exception)
100
+ return BaseCompletion(state="error", content=exception)
101
+
102
+
103
+ def chat_completion(self, message: List[dict]) -> ChatCompletion:
104
+ """
105
+ Chat completion method for OpenAI GPT API.
106
+
107
+ :param message: The message to use for completion.
108
+ :type message: List[dict]
109
+ :return: ChatCompletion object.
110
+ :rtype: ChatCompletion
111
+ """
112
+ try:
113
+ response = openai.ChatCompletion.create(
114
+ n=self.params.n,
115
+ model=self.model_name,
116
+ messages=message,
117
+ temperature=self.params.temperature,
118
+ max_tokens=self.params.max_tokens,
119
+ top_p=self.params.top_p,
120
+ frequency_penalty=self.params.frequency_penalty,
121
+ presence_penalty=self.params.presence_penalty,
122
+ )
123
+ return ChatCompletion(state="success",
124
+ role=response.choices[0].message["role"],
125
+ content=response.choices[0].message["content"],
126
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
127
+ completion_token=response.get("usage", {}).get("completion_tokens", 0))
128
+ except Exception as exception:
129
+ print("Exception:", exception)
130
+ return ChatCompletion(state="error", content=exception)
131
+
132
+ def stream_chat_completion(self, message: List[dict], **kwargs):
133
+ """
134
+ Stream output chat completion for OpenAI GPT API.
135
+
136
+ :param message: The message (scratchpad) to use for completion. Usually contains json of role and content.
137
+ :type message: List[dict]
138
+ :param kwargs: Additional keyword arguments.
139
+ :type kwargs: dict
140
+ :return: ChatCompletion object.
141
+ :rtype: ChatCompletion
142
+ """
143
+ try:
144
+ response = openai.ChatCompletion.create(
145
+ n=self.params.n,
146
+ model=self.model_name,
147
+ messages=message,
148
+ temperature=self.params.temperature,
149
+ max_tokens=self.params.max_tokens,
150
+ top_p=self.params.top_p,
151
+ frequency_penalty=self.params.frequency_penalty,
152
+ presence_penalty=self.params.presence_penalty,
153
+ stream=True,
154
+ **kwargs
155
+ )
156
+ role = next(response).choices[0].delta["role"]
157
+ messages = []
158
+ ## TODO: Calculate prompt_token and for stream mode
159
+ for resp in response:
160
+ messages.append(resp.choices[0].delta.get("content", ""))
161
+ yield ChatCompletion(state="success",
162
+ role=role,
163
+ content=messages[-1],
164
+ prompt_token=0,
165
+ completion_token=0)
166
+ except Exception as exception:
167
+ print("Exception:", exception)
168
+ return ChatCompletion(state="error", content=exception)
169
+
170
+ def function_chat_completion(self, message: List[dict],
171
+ function_map: Dict[str, Callable],
172
+ function_schema: List[Dict]) -> ChatCompletionWithHistory:
173
+ """
174
+ Chat completion method for OpenAI GPT API.
175
+
176
+ :param message: The message to use for completion.
177
+ :type message: List[dict]
178
+ :param function_map: The function map to use for completion.
179
+ :type function_map: Dict[str, Callable]
180
+ :param function_schema: The function schema to use for completion.
181
+ :type function_schema: List[Dict]
182
+ :return: ChatCompletionWithHistory object.
183
+ :rtype: ChatCompletionWithHistory
184
+ """
185
+ assert len(function_schema) == len(function_map)
186
+ try:
187
+ response = openai.ChatCompletion.create(
188
+ n=self.params.n,
189
+ model=self.model_name,
190
+ messages=message,
191
+ functions=function_schema,
192
+ temperature=self.params.temperature,
193
+ max_tokens=self.params.max_tokens,
194
+ top_p=self.params.top_p,
195
+ frequency_penalty=self.params.frequency_penalty,
196
+ presence_penalty=self.params.presence_penalty,
197
+ )
198
+ response_message = response.choices[0]["message"]
199
+
200
+ if response_message.get("function_call"):
201
+ function_name = response_message["function_call"]["name"]
202
+ fuction_to_call = function_map[function_name]
203
+ function_args = json.loads(response_message["function_call"]["arguments"])
204
+ function_response = fuction_to_call(**function_args)
205
+
206
+ # Postprocess function response
207
+ if isinstance(function_response, str):
208
+ plugin_cost = 0
209
+ plugin_token = 0
210
+ elif isinstance(function_response, AgentOutput):
211
+ plugin_cost = function_response.cost
212
+ plugin_token = function_response.token_usage
213
+ function_response = function_response.output
214
+ else:
215
+ raise Exception("Invalid tool response type. Must be on of [AgentOutput, str]")
216
+
217
+ message.append(dict(response_message))
218
+ message.append({"role": "function",
219
+ "name": function_name,
220
+ "content": function_response})
221
+ second_response = openai.ChatCompletion.create(
222
+ model=self.model_name,
223
+ messages=message,
224
+ )
225
+ message.append(dict(second_response.choices[0].message))
226
+ return ChatCompletionWithHistory(state="success",
227
+ role=second_response.choices[0].message["role"],
228
+ content=second_response.choices[0].message["content"],
229
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0) +
230
+ second_response.get("usage", {}).get("prompt_tokens", 0),
231
+ completion_token=response.get("usage", {}).get("completion_tokens", 0) +
232
+ second_response.get("usage", {}).get("completion_tokens", 0),
233
+ message_scratchpad=message,
234
+ plugin_cost=plugin_cost,
235
+ plugin_token=plugin_token,
236
+ )
237
+ else:
238
+ message.append(dict(response_message))
239
+ return ChatCompletionWithHistory(state="success",
240
+ role=response.choices[0].message["role"],
241
+ content=response.choices[0].message["content"],
242
+ prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
243
+ completion_token=response.get("usage", {}).get("completion_tokens", 0),
244
+ message_scratchpad=message)
245
+
246
+ except Exception as exception:
247
+ print("Exception:", exception)
248
+ return ChatCompletionWithHistory(state="error", content=str(exception))
249
+
250
+ def function_chat_stream_completion(self, message: List[dict],
251
+ function_map: Dict[str, Callable],
252
+ function_schema: List[Dict]) -> ChatCompletionWithHistory:
253
+ assert len(function_schema) == len(function_map)
254
+ try:
255
+ response = openai.ChatCompletion.create(
256
+ n=self.params.n,
257
+ model=self.model_name,
258
+ messages=message,
259
+ functions=function_schema,
260
+ temperature=self.params.temperature,
261
+ max_tokens=self.params.max_tokens,
262
+ top_p=self.params.top_p,
263
+ frequency_penalty=self.params.frequency_penalty,
264
+ presence_penalty=self.params.presence_penalty,
265
+ stream=True
266
+ )
267
+ tmp = next(response)
268
+ role = tmp.choices[0].delta["role"]
269
+ _type = "function_call" if tmp.choices[0].delta["content"] is None else "content"
270
+ if _type == "function_call":
271
+ name = tmp.choices[0].delta['function_call']['name']
272
+ yield _type, ChatCompletionWithHistory(state="success", role=role,
273
+ content="{" + f'"name":"{name}", "arguments":',
274
+ message_scratchpad=message)
275
+ for resp in response:
276
+ # print(resp)
277
+ content = resp.choices[0].delta.get(_type, "")
278
+ if isinstance(content, dict):
279
+ content = content['arguments']
280
+ yield _type, ChatCompletionWithHistory(state="success",
281
+ role=role,
282
+ content=content,
283
+ message_scratchpad=message)
284
+
285
+ # result = ''.join(messages)
286
+ # if _type == "function_call":
287
+ # result = json.loads(result)
288
+ # function_name = result["name"]
289
+ # fuction_to_call = function_map[function_name]
290
+ # function_args = result["arguments"]
291
+ # function_response = fuction_to_call(**function_args)
292
+ #
293
+ # # Postprocess function response
294
+ # if isinstance(function_response, AgentOutput):
295
+ # function_response = function_response.output
296
+ # message.append({"role": "function",
297
+ # "name": function_name,
298
+ # "content": function_response})
299
+ # second_response = self.function_chat_stream_completion(message=message,function_map=function_map,function_schema=function_schema)
300
+ # message.append(dict(second_response.choices[0].message))
301
+
302
+
303
+ except Exception as exception:
304
+ raise exception
305
+ print("Exception:", exception)
306
+ return ChatCompletion(state="error", content=str(exception))
src/infiagent/llm/client/opt.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from abc import ABC
5
+ from typing import Callable, List
6
+
7
+ import openai
8
+ from tenacity import ( # for exponential backoff
9
+ before_sleep_log,
10
+ retry,
11
+ stop_after_attempt,
12
+ wait_random_exponential,
13
+ )
14
+
15
+ from ..base_llm import BaseLLM
16
+ from ...schemas import *
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Prompts are truncated to their trailing MAX_PROMPT_LENGTH *characters*
# (not tokens) before being sent to the model.
MAX_PROMPT_LENGTH = 7000
21
+
22
+
23
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
def chatcompletion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with exponential-backoff retries.

    Retries up to 10 times with 1-10s random exponential waits, logging a
    warning before each sleep; re-raises the last exception when exhausted.
    """
    return openai.ChatCompletion.create(**kwargs)
27
+
28
+
29
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True,
       before_sleep=before_sleep_log(logger, logging.WARNING))
async def async_chatcompletion_with_backoff(**kwargs):
    """Async ``openai.ChatCompletion.acreate`` call with exponential-backoff
    retries (same policy as :func:`chatcompletion_with_backoff`).

    The original wrapped the call in a nested ``_internal_coroutine`` that
    added nothing; awaiting the API call directly is equivalent.
    """
    return await openai.ChatCompletion.acreate(**kwargs)
36
+
37
+
38
class OptOpenAIClient(BaseLLM, ABC):
    """
    Wrapper for an OPT model served through an OpenAI-compatible API
    (a local server at http://localhost:8000/v1).

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The sampling parameters for the model.
    :type params: OptParamModel

    NOTE(review): several methods read ``self.params.frequency_penalty`` /
    ``presence_penalty``; confirm OptParamModel declares those fields.
    """

    model_name: str
    params: OptParamModel = OptParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # The local OpenAI-compatible server ignores credentials, but the
        # openai client library still requires a non-empty key/base URL.
        openai.api_key = "EMPTY"
        openai.api_base = "http://localhost:8000/v1"

    @classmethod
    async def create(cls, config_data):
        """Async factory used by the agent bootstrap code."""
        return cls(**config_data)

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model_name

    def get_model_param(self) -> OptParamModel:
        """Return the configured sampling parameters."""
        return self.params

    def completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Single-turn completion.

        :param prompt: The prompt; only its trailing MAX_PROMPT_LENGTH
            characters are sent.
        :param kwargs: Extra keyword arguments forwarded to the API call.
        :return: BaseCompletion with content and token usage.
        """
        response = chatcompletion_with_backoff(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}],
            timeout=1000,
            **kwargs
        )
        return BaseCompletion(
            state="success",
            content=response.choices[0].message["content"],
            prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
            completion_token=response.get("usage", {}).get("completion_tokens", 0),
        )

    async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """Async variant of :meth:`completion` (same truncation and fields)."""
        response = await async_chatcompletion_with_backoff(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}],
            timeout=1000,
            **kwargs
        )
        return BaseCompletion(
            state="success",
            content=response.choices[0].message["content"],
            prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
            completion_token=response.get("usage", {}).get("completion_tokens", 0),
        )

    def chat_completion(self, message: List[dict]) -> ChatCompletion:
        """
        Multi-turn chat completion.

        :param message: Chat history as role/content dicts.
        :return: ChatCompletion; on failure state is "error" and content
            holds the exception text.
        """
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
            )
            return ChatCompletion(
                state="success",
                role=response.choices[0].message["role"],
                content=response.choices[0].message["content"],
                prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                completion_token=response.get("usage", {}).get("completion_tokens", 0),
            )
        except Exception as exception:
            logger.error("chat_completion failed: %s", exception, exc_info=True)
            # Fix: pass the message text, not the exception object, so the
            # pydantic `content: str` field validates.
            return ChatCompletion(state="error", content=str(exception))

    def stream_chat_completion(self, message: List[dict], **kwargs):
        """
        Streaming chat completion; a generator of ChatCompletion chunks.

        :param message: Chat history (role/content dicts).
        :param kwargs: Extra keyword arguments forwarded to the API call.
        :yield: One ChatCompletion per streamed content delta; on failure a
            single error ChatCompletion is yielded (the original *returned*
            it, which a generator silently discards).
        """
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
                **kwargs
            )
            # The first streamed chunk carries only the role.
            role = next(response).choices[0].delta["role"]
            # TODO: Calculate prompt/completion tokens for stream mode.
            for resp in response:
                yield ChatCompletion(
                    state="success",
                    role=role,
                    content=resp.choices[0].delta.get("content", ""),
                    prompt_token=0,
                    completion_token=0,
                )
        except Exception as exception:
            logger.error("stream_chat_completion failed: %s", exception, exc_info=True)
            yield ChatCompletion(state="error", content=str(exception))

    def function_chat_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion with OpenAI function calling.

        When the model requests a function call, the mapped callable is
        invoked and a second completion is made with the function result
        appended to the history.

        :param message: Chat history; mutated in place with the new turns.
        :param function_map: Mapping from function name to callable.
        :param function_schema: JSON schemas, parallel to ``function_map``.
        :return: ChatCompletionWithHistory with the final message, updated
            scratchpad and token/plugin accounting; state "error" on failure.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.model_name,
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
            )
            response_message = response.choices[0]["message"]

            if response_message.get("function_call"):
                function_name = response_message["function_call"]["name"]
                function_to_call = function_map[function_name]  # typo fixed: fuction_to_call
                function_args = json.loads(
                    response_message["function_call"]["arguments"]
                )
                function_response = function_to_call(**function_args)

                # Normalize the tool response and its accounting.
                if isinstance(function_response, str):
                    plugin_cost = 0
                    plugin_token = 0
                elif isinstance(function_response, AgentOutput):
                    plugin_cost = function_response.cost
                    plugin_token = function_response.token_usage
                    function_response = function_response.output
                else:
                    raise Exception(
                        "Invalid tool response type. Must be one of [AgentOutput, str]"
                    )

                message.append(dict(response_message))
                message.append(
                    {
                        "role": "function",
                        "name": function_name,
                        "content": function_response,
                    }
                )
                # Second round: let the model read the function result.
                second_response = openai.ChatCompletion.create(
                    model=self.get_model_name(),
                    messages=message,
                )
                message.append(dict(second_response.choices[0].message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=second_response.choices[0].message["role"],
                    content=second_response.choices[0].message["content"],
                    prompt_token=response.get("usage", {}).get("prompt_tokens", 0)
                    + second_response.get("usage", {}).get("prompt_tokens", 0),
                    completion_token=response.get("usage", {}).get("completion_tokens", 0)
                    + second_response.get("usage", {}).get("completion_tokens", 0),
                    message_scratchpad=message,
                    plugin_cost=plugin_cost,
                    plugin_token=plugin_token,
                )

            message.append(dict(response_message))
            return ChatCompletionWithHistory(
                state="success",
                role=response.choices[0].message["role"],
                content=response.choices[0].message["content"],
                prompt_token=response.get("usage", {}).get("prompt_tokens", 0),
                completion_token=response.get("usage", {}).get("completion_tokens", 0),
                message_scratchpad=message,
            )

        except Exception as exception:
            logger.error("function_chat_completion failed: %s", exception, exc_info=True)
            return ChatCompletionWithHistory(state="error", content=str(exception))

    def function_chat_stream_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Streaming variant of function-calling chat completion.

        :yield: ``(type, ChatCompletionWithHistory)`` pairs where ``type`` is
            "function_call" (arguments streamed as JSON fragments, prefixed
            with a ``{"name": ..., "arguments":`` opener) or "content".
        :raises Exception: re-raises any API/streaming failure after logging.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                n=self.params.n,
                model=self.get_model_name(),
                messages=message,
                functions=function_schema,
                temperature=self.params.temperature,
                max_tokens=self.params.max_tokens,
                top_p=self.params.top_p,
                frequency_penalty=self.params.frequency_penalty,
                presence_penalty=self.params.presence_penalty,
                stream=True,
            )
            first = next(response)
            role = first.choices[0].delta["role"]
            # A None content on the first delta signals a function call.
            _type = (
                "function_call"
                if first.choices[0].delta["content"] is None
                else "content"
            )
            if _type == "function_call":
                name = first.choices[0].delta["function_call"]["name"]
                # Emit the JSON prefix so the consumer can assemble the full
                # {"name": ..., "arguments": ...} object incrementally.
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )
            for resp in response:
                content = resp.choices[0].delta.get(_type, "")
                if isinstance(content, dict):
                    content = content["arguments"]
                yield _type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )
        except Exception as e:
            logger.error(f"Failed to get response {str(e)}", exc_info=True)
            raise e
src/infiagent/prompt/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .prompt_template import *
2
+ from .simple_react_prompt import SimpleReactPrompt
3
+ from .zero_shot_react_prompt import ZeroShotReactPrompt
src/infiagent/prompt/prompt_template.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt schema definition."""
2
+ from abc import ABC, abstractmethod
3
+ from string import Formatter
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ from pydantic import BaseModel, Extra, root_validator
7
+
8
+ from ..exceptions.exceptions import InputErrorException
9
+ from ..schemas import AgentAction, AgentObservation, BaseAgentResponse
10
+
11
# Keys into a PromptTemplate's _keywords mapping.
OBSERVATION_KEY = "Observation"
THOUGHT_KEY = "Thought"
FINAL_ANSWER_KEY = "FinalAnswer"

# Fallback marker strings used when a template does not override a keyword.
DEFAULT_OBSERVATION = "Observation:"
DEFAULT_THOUGHT = "Thought:"
DEFAULT_FINAL_ANSWER = "Final Answer:"
18
+
19
+
20
class PromptTemplate(BaseModel, ABC):
    """
    Base prompt template: a str.format-style template, the variables it
    consumes, and the keyword markers used to stitch the agent scratchpad
    together. Subclasses override the underscore class attributes.
    """
    _input_variables: List[str]   # names that format() requires in kwargs
    _template: str                # str.format-style template string
    _keywords: Dict[str, str]     # marker overrides keyed by *_KEY constants
    _name: str
    _validate_template: bool
    _skip_on_failure: bool

    class Config:
        # Reject unexpected constructor keyword arguments.
        extra = Extra.forbid

    @property
    def input_variables(self) -> List[str]:
        return self._input_variables

    @property
    def template(self) -> str:
        return self._template

    @property
    def keywords(self) -> Dict[str, str]:
        return self._keywords

    @property
    def name(self) -> str:
        return self._name

    def format(self, **kwargs):
        """Fill the template with kwargs.

        Raises InputErrorException when any declared input variable is
        missing; extra kwargs are silently ignored.
        """
        if not set(self._input_variables).issubset(kwargs.keys()):
            missing_keys = set(self._input_variables) - kwargs.keys()
            raise InputErrorException(f"Missing keys in prompt template: {', '.join(missing_keys)}")

        # Only pass the declared variables to str.format.
        filtered_kwargs = {key: kwargs[key] for key in self._input_variables if key in kwargs}

        return self._template.format(**filtered_kwargs)

    def construct_scratchpad(self, intermediate_steps: List[BaseAgentResponse]) -> str:
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""

        for agent_response in intermediate_steps:
            if isinstance(agent_response, AgentAction):
                # for agent action, use the raw model output (the Thought text)
                thoughts += agent_response.raw_output
            elif isinstance(agent_response, AgentObservation):
                # for agent observation, wrap the tool output between the
                # Observation marker and a fresh Thought marker
                thoughts += f"\n{self.keywords.get(OBSERVATION_KEY, DEFAULT_OBSERVATION)}\n" \
                            f"{agent_response.formatted_output}\n\n" \
                            f"{self.keywords.get(THOUGHT_KEY, DEFAULT_THOUGHT)}\n"

        return thoughts

    @classmethod
    @root_validator(skip_on_failure=True)
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that template and input variables are consistent.

        NOTE(review): this reads values["validate_template"] / ["template"]
        / ["input_variables"], but the declared attributes are underscore-
        prefixed private attrs, which pydantic does not expose under these
        keys — confirm this validator ever actually fires.
        """
        if values["validate_template"]:
            try:
                dummy_input = {var: "" for var in values["input_variables"]}
                Formatter().format(values["template"], **dummy_input)
            except KeyError as e:
                raise InputErrorException("Invalid prompt schema; check for mismatched or missing input parameters. ")\
                    from e
        return values
src/infiagent/prompt/simple_react_prompt.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..prompt import FINAL_ANSWER_KEY, OBSERVATION_KEY, THOUGHT_KEY, PromptTemplate
2
+
3
+
4
class SimpleReactPrompt(PromptTemplate):
    """ReAct prompt that emits the raw instruction followed by the
    scratchpad, using special-token markers ([EOS]/[SEP]/[END]) instead of
    the default Observation/Thought/Final Answer strings."""
    _input_variables = ["instruction", "agent_scratchpad"]
    _template = "{instruction} \n{agent_scratchpad}"
    _keywords = {
        OBSERVATION_KEY: "[EOS]Observation:",
        THOUGHT_KEY: "[SEP]",
        FINAL_ANSWER_KEY: "[END]"
    }
    _name = 'SimpleReactPrompt'
    _validate_template = True
    _skip_on_failure = True
    # The redundant __init__ that only forwarded **data to super() has been
    # removed; the inherited constructor behaves identically.
src/infiagent/prompt/zero_shot_react_prompt.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..prompt import PromptTemplate, OBSERVATION_KEY, THOUGHT_KEY, FINAL_ANSWER_KEY, DEFAULT_OBSERVATION, \
2
+ DEFAULT_THOUGHT, DEFAULT_FINAL_ANSWER
3
+
4
+
5
class ZeroShotReactPrompt(PromptTemplate):
    """Zero-shot ReAct prompt: lists the available tools, spells out the
    Thought/Action/Action Input/Observation loop, and appends the running
    scratchpad after the question. Uses the default keyword markers."""
    _input_variables = ["instruction", "agent_scratchpad", "tool_names", "tool_description"]
    _template = (
        "Answer the following questions as best you can."
        "You have access to the following tools:\n"
        "{tool_description}.\n"
        "Use the following format:\n\n"
        "Question: the input question you must answer\n"
        "Thought: you should always think about what to do\n\n"
        "Action: the action to take, should be one of [{tool_names}]\n\n"
        "Action Input:\n```python\n[the input to the action]\n```\n"
        "Observation: the result of the action\n\n"
        "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
        "Thought: I now know the final answer\n"
        "Final Answer: the final answer to the original input question\n"
        "If you have any files outputted write them to \"./\"\n"
        "Do not use things like plot.show() as it will not work instead write them out \"./\"\n"
        "Begin!\n\n"
        "Question: {instruction}\nThought:\n"
        "{agent_scratchpad}\n"
    )
    _keywords = {
        OBSERVATION_KEY: DEFAULT_OBSERVATION,
        THOUGHT_KEY: DEFAULT_THOUGHT,
        FINAL_ANSWER_KEY: DEFAULT_FINAL_ANSWER
    }
    _name = 'ZeroShotReactPrompt'
    _validate_template = True
    _skip_on_failure = True
    # The redundant __init__ that only forwarded **data to super() has been
    # removed; the inherited constructor behaves identically.
src/infiagent/schemas/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .base_models import *
2
+ from .complete_models import *
3
+ from .sandbox_models import *
4
+ from .agent_models import *
5
+ from .llm_models import *
src/infiagent/schemas/agent_models.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import List, NamedTuple, Optional, Union
7
+
8
+ from pydantic import BaseModel
9
+
10
+ from ..schemas.sandbox_models import *
11
+
12
+
13
@dataclass
class BaseAgentResponse:
    """Base Agent step result, contains formatted output string."""
    formatted_output: str  # text as inserted into the scratchpad/prompt
    raw_output: str        # unmodified text from the model/tool
18
+
19
+
20
@dataclass
class AgentAction(BaseAgentResponse):
    """
    Agent's action to take: which tool to invoke and with what input.
    """
    tool: str                      # name of the tool to invoke
    tool_input: Union[str, dict]   # raw string or structured arguments
27
+
28
+
29
@dataclass
class AgentObservation(BaseAgentResponse):
    """
    Observation produced by running a tool (the original docstring said
    "Agent's action to take" — copy-paste from AgentAction).
    """
    tool: str  # name of the tool that produced this observation
35
+
36
+
37
@dataclass
class AgentFinish(BaseAgentResponse):
    """Agent's return value when finishing execution."""
    # No extra fields beyond the base formatted/raw outputs.
    pass
41
+
42
+
43
class AgentType(Enum):
    """
    Enumerated type for agent types.
    """
    openai = "openai"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"
    openai_memory = "openai_memory"

    @staticmethod
    def get_agent_class(_type: AgentType):
        """
        Get agent class from agent type.
        :param _type: agent type
        :return: agent class
        :raises ValueError: for any type other than ``react`` — only the
            ReactAgent is wired up here.
        """
        if _type == AgentType.react:
            # Imported lazily to avoid a circular import with the agent package.
            from ..agent.react import ReactAgent
            return ReactAgent
        else:
            raise ValueError(f"Unknown agent type: {_type}")
65
+
66
+
67
class AgentOutput(BaseModel):
    """
    Pydantic model for agent output.
    """
    output: str        # textual result of the agent/tool call
    cost: float        # cost attributed to the call (units set by caller)
    token_usage: int   # tokens consumed producing the output
74
+
75
+
76
@dataclass
class AgentRequest:
    """Inbound request to the agent: chat history plus sandbox context."""
    sandbox_id: Optional[str] = None
    messages: List[Message] = field(default_factory=list)
    input_files: List[MediaFile] = field(default_factory=list)
    sandbox_status: Optional[SandboxStatus] = None
    is_cn: bool = False  # presumably toggles Chinese-language handling — TODO confirm
83
+
84
+
85
+
86
@dataclass
class AgentResponse:
    """Agent reply: final text, generated files, and sandbox bookkeeping."""
    output_text: str       # rendered answer shown to the user
    raw_output_text: str   # unprocessed model output
    output_files: List[MediaFile] = field(default_factory=list)
    sandbox_id: Optional[str] = None
    sandbox_status: Optional[SandboxStatus] = None
    # Per-turn prompt/response transcripts, when the caller requests them.
    turn_level_prompt: Optional[List[str]] = None
    turn_level_response: Optional[List[str]] = None
95
+
96
+
97
class RoleType(Enum):
    """Conversation participant type.

    Can be constructed from the integer value or, via ``_missing_``, from a
    case-insensitive member name string, e.g. ``RoleType("user")``.
    """
    User = 0
    System = 1
    Agent = 2

    @classmethod
    def _missing_(cls, name):
        # Fall back to case-insensitive name matching for string inputs;
        # anything else defers to the default Enum behavior.
        if isinstance(name, str):
            lowered = name.lower()
            for candidate in cls:
                if candidate.name.lower() == lowered:
                    return candidate
        return super()._missing_(name)
110
+
111
+
112
@dataclass
class Message(abc.ABC):
    """A single chat message: role plus rendered and raw content."""
    role: RoleType
    content: str
    raw_content: str = ""  # unprocessed model output; "" for legacy records

    @staticmethod
    def parse_from_dict(data):
        """Build a Message from a dict, coercing ``role`` to RoleType.

        NOTE(review): mutates the input dict in place.
        """
        data['role'] = RoleType(data['role'])
        # Add a check for raw_content in legacy data
        if 'raw_content' not in data:
            data['raw_content'] = ""
        return Message(**data)

    def to_dict(self):
        """Serialize back to a plain dict, storing the role's raw value."""
        role_value = self.role.value if isinstance(self.role, RoleType) else self.role
        return {
            "role": role_value,
            "content": self.content,
            "raw_content": self.raw_content
        }
133
+
134
+
135
@dataclass
class MediaFile:
    """A file exchanged with the sandbox, identified by name and/or paths."""
    file_name: Optional[str] = None
    file_content: Optional[bytes] = None
    tos_path: Optional[str] = None
    sandbox_path: Optional[str] = None

    # NOTE(review): naming this method __dict__ shadows the built-in instance
    # attribute dictionary (vars(obj) returns this bound method). Kept as-is
    # for caller compatibility.
    def __dict__(self):
        """Plain-dict view where every unset field becomes ""."""
        empty = ""
        return {
            'file_name': empty if self.file_name is None else self.file_name,
            'file_content': empty if self.file_content is None else self.file_content,
            'tos_path': empty if self.tos_path is None else self.tos_path,
            'sandbox_path': empty if self.sandbox_path is None else self.sandbox_path,
        }
src/infiagent/schemas/base_models.py ADDED
File without changes
src/infiagent/schemas/complete_models.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ from datetime import datetime
3
+ from time import time
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from ..schemas.agent_models import Message
9
+ from ..utils.file_utils import get_file_name_and_path
10
+
11
+ # Definitions for inputs and outputs schema for /complete api
12
+
13
# Fallback sampling settings used when the request leaves them unset.
DEFAULT_TOP_P = 0.7
DEFAULT_TEMPERATURE = 1.0
DEFAULT_STREAM = False

# Delta.status values surfaced to the client.
FINISH_STATUS = "FINISH"
FAILED_STATUS = "FAILED"
PROCESSING_STATUS = "PROCESSING"
# Role string used for all server-produced messages.
ASSISTANT = "assistant"
21
+
22
+
23
+ # Main Input Model
24
class ChatCompleteRequest(BaseModel):
    """Request body of the /complete chat API."""
    chat_id: str  # unique chat id for given chat
    code_interpreter: Optional[dict] = {}  # may carry "tos_key" for an uploaded input file
    messages: List[dict] = []  # chat message history (role/content dicts)
    model: str = "AZURE_OPEN_AI"  # model name map to LLM conf
    user: str  # caller identity; copied into Delta.creator/updater
    max_tokens: Optional[int] = None
    message_conf: Optional[dict] = {}
    n: Optional[int] = None
    plugins: Optional[List[str]] = None
    seed_conf: Optional[dict] = {}
    stream: Optional[bool] = None   # None -> DEFAULT_STREAM
    temperature: Optional[float] = None  # None -> DEFAULT_TEMPERATURE
    top_p: Optional[float] = None        # None -> DEFAULT_TOP_P
    top_k: Optional[int] = None
    webgpt: Optional[Dict[str, Any]] = None
    webgpt_network: Optional[bool] = None
41
+
42
+
43
class MessageConf(BaseModel):
    """Per-message generation configuration echoed back in each Delta."""
    top_p: float = DEFAULT_TOP_P
    temperature: float = DEFAULT_TEMPERATURE
    top_k: Optional[int] = None
    time_cost: int               # elapsed time bookkeeping; 0 when unknown
    code_interpreter: dict       # e.g. {"tos_key": <input file name>}
    gpt_engine_conf: dict
    stream: bool
51
+
52
+
53
class Delta(BaseModel):
    """One message delta inside a /complete response choice."""
    role: str           # always ASSISTANT for server-produced deltas here
    content: str
    sid: str
    status: str         # FINISH / FAILED / PROCESSING
    end_turn: bool
    parent_id: str
    children_ids: Optional[Union[List[str], None]]
    err_msg: str
    creator: str
    updater: str
    ctime: str          # creation timestamp string
    utime: str          # last-update timestamp string
    message_conf: MessageConf

    def json(self, *args, **kwargs):
        """Serialize like BaseModel.json, normalizing "+00:00" offsets to "Z"."""
        serialized_data = super().json(*args, **kwargs)
        return serialized_data.replace("+00:00", "Z")
71
+
72
+
73
class Choice(BaseModel):
    """One response alternative: index, delta payload, and finish reason."""
    index: int
    delta: Delta
    finish_reason: str  # "stop" for completed turns
77
+
78
+
79
class ChatCompleteResponse(BaseModel):
    """Top-level /complete API response envelope."""
    id: str        # echoes the request chat_id
    created: int   # Unix timestamp of response creation
    choices: List[Choice]
83
+
84
+
85
def chat_request_to_message_conf(chat_request: ChatCompleteRequest) -> MessageConf:
    """Derive a MessageConf from a request, filling defaults for
    top_p / temperature / stream and extracting the code-interpreter input
    file name (if any).

    :param chat_request: Incoming /complete request.
    :return: MessageConf with zeroed time_cost and empty gpt_engine_conf.
    """
    input_files = {}

    if chat_request.code_interpreter and "tos_key" in chat_request.code_interpreter:
        input_file = chat_request.code_interpreter["tos_key"]
        # Only the bare file name is surfaced; the tos path half of the
        # tuple was an unused local in the original.
        file_name, _ = get_file_name_and_path(input_file)
        input_files = {"tos_key": file_name}

    return MessageConf(
        top_p=chat_request.top_p if chat_request.top_p is not None else DEFAULT_TOP_P,
        temperature=chat_request.temperature if chat_request.temperature is not None else DEFAULT_TEMPERATURE,
        code_interpreter=input_files,
        time_cost=0,
        gpt_engine_conf={},
        stream=chat_request.stream if chat_request.stream is not None else DEFAULT_STREAM
    )
101
+
102
+
103
def chat_request_to_deltas(chat_request: ChatCompleteRequest) -> List[Delta]:
    """Build one finished assistant Delta per request message; all deltas
    share a single MessageConf derived from the request."""
    shared_conf = chat_request_to_message_conf(chat_request)

    return [
        Delta(
            role=ASSISTANT,
            content=msg["content"],
            sid="",
            status="FINISH",
            end_turn=False,
            parent_id="",
            children_ids=None,
            err_msg="",
            creator=chat_request.user,
            updater=chat_request.user,
            ctime=current_utc_time_as_str(),
            utime=current_utc_time_as_str(),
            message_conf=shared_conf,
        )
        for msg in chat_request.messages
    ]
126
+
127
+
128
def chat_request_to_choices(chat_request: ChatCompleteRequest) -> List[Choice]:
    """Wrap each delta generated from the request in a Choice carrying its
    list index and a "stop" finish reason."""
    return [
        Choice(index=position, delta=one_delta, finish_reason="stop")
        for position, one_delta in enumerate(chat_request_to_deltas(chat_request))
    ]
141
+
142
+
143
def chat_request_to_response(chat_request: ChatCompleteRequest) -> ChatCompleteResponse:
    """Assemble the full ChatCompleteResponse for a request, stamped with
    the current Unix time and the request's chat id."""
    all_choices = chat_request_to_choices(chat_request)
    created_at = int(time())
    return ChatCompleteResponse(
        id=chat_request.chat_id,
        created=created_at,
        choices=all_choices,
    )
149
+
150
+
151
def update_chat_response_with_message(chat_response: ChatCompleteResponse,
                                      message: Message,
                                      status: Union[str, None] = None) -> ChatCompleteResponse:
    """Replace the response's choices with a single choice carrying
    ``message.content``, inheriting bookkeeping fields (sid, parent_id,
    timestamps, message_conf) from the previous last delta when present.

    :param chat_response: Response to update; mutated in place and returned.
    :param message: New message whose content becomes the delta content.
    :param status: Delta status override; defaults to FINISH_STATUS.
    """
    # Get the last Delta (if exists)
    last_delta = chat_response.choices[-1].delta if chat_response.choices else None
    updated_delta = Delta(
        role=ASSISTANT,  # map with front end
        content=message.content,
        sid=last_delta.sid if last_delta else "",
        status=status if status is not None else FINISH_STATUS,
        end_turn=False,
        parent_id=last_delta.parent_id if last_delta else "",
        children_ids=last_delta.children_ids if last_delta else None,
        err_msg="",
        # NOTE(review): creator/updater fall back to None, but Delta declares
        # them as `str` — confirm pydantic accepts this path.
        creator=last_delta.creator if last_delta else None,
        updater=last_delta.updater if last_delta else None,
        ctime=last_delta.ctime if last_delta else current_utc_time_as_str(),
        utime=current_utc_time_as_str(),
        message_conf=MessageConf(
            # Falsy previous values (e.g. 0.0) also fall back to defaults here.
            top_p=last_delta.message_conf.top_p if last_delta and last_delta.message_conf.top_p else DEFAULT_TOP_P,
            temperature=last_delta.message_conf.temperature if last_delta and last_delta.message_conf.temperature else
            DEFAULT_TEMPERATURE,
            code_interpreter=last_delta.message_conf.code_interpreter
            if last_delta and last_delta.message_conf.code_interpreter else {},
            time_cost=0,
            gpt_engine_conf={},
            stream=last_delta.message_conf.stream if last_delta and last_delta.message_conf.stream is not None else
            False
        )
    )

    updated_choice = Choice(
        index=0,  # Since it's the only choice in the list
        delta=updated_delta,
        finish_reason="stop"
    )

    # Update the ChatCompleteResponse to contain only the new Choice
    chat_response.choices = [updated_choice]
    return chat_response
191
+
192
+
193
def current_utc_time_as_str() -> str:
    """Return the current UTC time as an ISO-8601 'Z' timestamp,
    e.g. ``2024-01-31T12:00:00Z``."""
    # Local import: the module only imports `datetime` at the top.
    from datetime import timezone
    # datetime.utcnow() is deprecated (Python 3.12); an aware UTC datetime
    # produces byte-identical strftime output for this format.
    return datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
195
+
196
+
197
def create_empty_response():
    """Return a blank ChatCompleteResponse skeleton (every field zeroed or
    empty), used as a placeholder before any real content is produced."""
    blank_conf = MessageConf(
        top_p=0.0,
        temperature=0,
        time_cost=0,
        code_interpreter={},
        gpt_engine_conf={},
        stream=False,
    )
    blank_delta = Delta(
        role=ASSISTANT,
        content="",
        sid="",
        status="",
        end_turn=False,
        parent_id="",
        children_ids=None,
        err_msg="",
        creator="",
        updater="",
        ctime="",
        utime="",
        message_conf=blank_conf,
    )
    return ChatCompleteResponse(
        id="",
        created=0,
        choices=[Choice(index=0, delta=blank_delta, finish_reason="")],
    )
236
+
src/infiagent/schemas/llm_models.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Dict, List, NamedTuple, Union
6
+
7
+ from pydantic import BaseModel
8
+
9
+ try:
10
+ import torch
11
+ except ImportError:
12
+ pass
13
+
14
+
15
+
16
class BaseCompletion(BaseModel):
    """Result of a single LLM completion call.

    ``state`` is "success" or "error"; token counters default to 0 when the
    backend does not report usage.
    """
    state: str  # "success" or "error"
    content: str
    prompt_token: int = 0
    completion_token: int = 0

    def to_dict(self):
        """Plain-dict view of the completion."""
        return {
            "state": self.state,
            "content": self.content,
            "prompt_token": self.prompt_token,
            "completion_token": self.completion_token,
        }
29
+
30
+
31
class ChatCompletion(BaseCompletion):
    """Completion from a chat-style endpoint; adds the speaker role."""
    role: str = "assistant"  # "system" or "user" or "assistant"
33
+
34
+
35
class ChatCompletionWithHistory(ChatCompletion):
    """Used for function call API"""
    # Full message scratchpad (chat history including function turns).
    message_scratchpad: List[Dict] = []
    # Cost/token usage attributed to plugin (tool) execution.
    plugin_cost: float = 0.0
    plugin_token: float = 0.0
40
+
41
+
42
class BaseParamModel(BaseModel):
    """Base for LLM parameter models; equality compares field dicts."""
    def __eq__(self, other):
        # NOTE(review): assumes `other` is also a pydantic model (has .dict());
        # defining __eq__ without __hash__ also makes instances unhashable.
        return self.dict() == other.dict()
45
+
46
+
47
class OpenAIParamModel(BaseModel):
    """
    OpenAI API parameters
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1        # number of completions per request
    stop: list = []   # stop sequences (pydantic copies this default per instance)
58
+
59
class AzureOpenAIParamModel(BaseModel):
    """
    AzureOpenAI API parameters
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1        # number of completions per request
    stop: list = []   # stop sequences
70
+
71
class LlamaParamModel(BaseModel):
    """
    Llama API parameters (the original docstring said "AzureOpenAI" —
    copy-paste from the model above).
    """
    max_tokens: int = 4096
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1        # number of completions per request
    stop: list = []   # stop sequences
82
+
83
class OptParamModel(BaseModel):
    """
    OPT (OpenAI-compatible local server) API parameters.

    Adds presence/frequency penalties for parity with the sibling param
    models: OptOpenAIClient reads both ``params.frequency_penalty`` and
    ``params.presence_penalty``, which the original model did not declare
    (every call would raise AttributeError). Defaults of 0.0 keep existing
    configs backward compatible. Also fixes the copy-pasted "AzureOpenAI"
    docstring.
    """
    max_tokens: int = 2048
    temperature: float = 0.2
    top_p: float = 1.0
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    n: int = 1        # number of completions per request
    stop: list = []   # stop sequences
src/infiagent/schemas/sandbox_models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Any, List, Optional
3
+ from pydantic import BaseModel
4
+
5
class SandboxStatus(Enum):
    """
    Enumerated type for sandbox run status.
    """
    success = "success"
    failed = "failed"
    timeout = "timeout"
12
+
13
class CodeOutput(BaseModel):
    """One piece of output emitted by code executed in the sandbox."""
    type: str  # output kind reported by the sandbox -- valid values not shown here, verify
    content: str  # the output payload itself
16
+
17
class ReturnedFile(BaseModel):
    """A file reported by the sandbox, with a link to retrieve it."""
    download_link: str  # URL from which the file can be downloaded
    name: str  # file name
    path: str  # path of the file inside the sandbox
21
+
22
class CodeRunResult(BaseModel):
    """Result payload of one sandbox code run."""
    code_output_result: List[CodeOutput]  # outputs emitted by the executed code
    deleted_files: List[ReturnedFile]  # files removed during the run
    new_generated_files: List[ReturnedFile]  # files created during the run
26
+
27
class CodeRunData(BaseModel):
    """`data` section of a run-code response."""
    is_partial: bool  # presumably True while more output is still pending -- verify against server
    result: CodeRunResult
30
+
31
+
32
class RunCodeOutput(BaseModel):
    """Top-level response of the sandbox run-code endpoint."""
    code: int  # status code returned by the sandbox service
    message: str
    data: Optional[CodeRunData]  # absent on failure
36
+
37
class CreateSessionOutput(BaseModel):
    """Response of the create-session endpoint."""
    code: int  # status code returned by the sandbox service
    message: str
40
+
41
+
42
class ErrorResponse(BaseModel):
    """Generic error envelope returned by the sandbox service."""
    code: int  # status code returned by the sandbox service
    message: str
    data: Optional[Any]  # optional extra error detail
46
+
47
+
48
class UploadOutput(BaseModel):
    """Response of the file-upload endpoint."""
    code: int  # status code returned by the sandbox service
    message: Optional[str]
    data: Optional[str]  # payload string -- exact contents not shown here, verify
52
+
53
+
54
# Model for successful response (assuming it's a text file for this example)
class DownloadSuccessOutput(BaseModel):
    file_name: str  # this is not part of server response. We must fill this field in client.
    content: str  # downloaded file contents
58
+
59
+
60
class HeartbeatOutput(BaseModel):
    """Response of the sandbox heartbeat endpoint."""
    code: Optional[int]  # status code returned by the sandbox service
    message: Optional[str]
63
+
64
+
65
class RefreshSandboxOutput(BaseModel):
    """Response of the refresh-sandbox endpoint."""
    code: Optional[int]  # status code returned by the sandbox service
    message: Optional[str]
68
+
69
+
src/infiagent/services/__init__.py ADDED
File without changes
src/infiagent/services/chat_complete_service.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time
from io import BytesIO
from typing import Any, Dict, List, Union

from fastapi import UploadFile
from starlette.datastructures import UploadFile as StarletteUploadFile
from werkzeug.datastructures import FileStorage

from ..conversation_sessions import CodeInterpreterSession
from ..exceptions.exceptions import (
    DependencyException,
    InputErrorException,
    InternalErrorException,
    ModelMaxIterationsException,
)
from ..schemas import Message, RoleType
from ..tools import AsyncPythonSandBoxTool
from ..utils import get_logger, upload_files
19
+
20
+ logger = get_logger()
21
+
22
+
23
async def predict(
        prompt: str,
        model_name: str,
        config_path: str,
        uploaded_files: Any,
        **kwargs: Dict[str, Any]):
    """Run one code-interpreter chat turn and return the agent's reply text.

    Creates a fresh ``CodeInterpreterSession``, uploads the given files into
    the sandbox, streams the agent's chat responses for ``prompt``, and always
    releases the sandbox kernel afterwards.

    Args:
        prompt: User instruction forwarded to the agent.
        model_name: Name of the LLM backend for the session.
        config_path: Path to the agent configuration file.
        uploaded_files: Raw uploads; normalized by ``upload_files`` into a
            file path or a list of paths / uploaded-file objects.
        **kwargs: Extra options forwarded to ``CodeInterpreterSession.create``.

    Returns:
        The concatenated agent response text (``None`` if the agent yields
        no responses).

    Raises:
        InputErrorException: If an uploaded file has an unsupported type.
        Exception: With a user-facing message when the chat itself fails.
    """
    start_time = time.time()

    # Create a new session for this request.
    session = await CodeInterpreterSession.create(
        model_name=model_name,
        config_path=config_path,
        **kwargs
    )

    files = upload_files(uploaded_files, session.session_id)
    logger.info(f"Session Creation Latency: {time.time() - start_time}")

    # Upload a single file path, or a list of paths / uploaded-file objects.
    if isinstance(files, str):
        logger.info(f"Upload {files} as file path")
        await session.upload_to_sandbox(files)
    elif isinstance(files, list):
        for file in files:
            if isinstance(file, str):
                await session.upload_to_sandbox(file)
            elif isinstance(file, (UploadFile, StarletteUploadFile)):
                # Wrap the in-memory upload as a FileStorage so the sandbox
                # client can stream it.
                file_content = file.file.read()  # get file content
                file_storage = FileStorage(
                    stream=BytesIO(file_content),
                    filename=file.filename,
                    content_type=file.content_type,
                )
                await session.upload_to_sandbox(file_storage)
            else:
                raise InputErrorException(
                    "The file type {} not supported, can't be uploaded".format(type(file)))

    # Chat with the agent; the sandbox kernel is released in ``finally`` so it
    # is no longer leaked when the chat raises (it used to be killed only on
    # the success path).
    try:
        logger.info(f"Instruction message: {prompt}")
        content = None
        output_files = []
        user_messages = [Message(RoleType.User, prompt)]
        async for response in session.chat(user_messages):
            logger.info(f'Session Chat Response: {response}')
            if content is None:
                content = response.output_text
            else:
                content += response.output_text

            output_files.extend(output_file.__dict__() for output_file in response.output_files)

        session.messages.append(Message(RoleType.Agent, content))
        logger.info(f"Total Latency: {time.time() - start_time}")

        return content
    except Exception as e:
        # Map known failure modes to user-facing messages. isinstance-based
        # matching also covers subclasses; the old exact ``type(e)`` dict
        # lookup silently fell through for them.
        known_errors = (
            (ModelMaxIterationsException, "Sorry. The agent didn't find the correct answer after multiple trials, "
                                          "Please try another question."),
            (DependencyException, "Agent failed to process message due to dependency issue. You can try it later. "
                                  "If it still happens, please contact oncall."),
            (InputErrorException, "Agent failed to process message due to value issue. If you believe all input are "
                                  "correct, please contact oncall."),
            (InternalErrorException, "Agent failed to process message due to internal error, please contact oncall."),
        )
        err_msg = next(
            (msg for exc_type, msg in known_errors if isinstance(e, exc_type)),
            "Agent failed to process message due to unknown error, please contact oncall.",
        )
        logger.error(err_msg, exc_info=True)

        # Chain the original exception so the root cause stays in tracebacks.
        raise Exception(err_msg) from e
    finally:
        AsyncPythonSandBoxTool.kill_kernels(session.session_id)
        logger.info(f"Release python sandbox {session.session_id}")
100
+
101
+ import time
102
+ from typing import Union, List, Any, Dict
103
+ from io import BytesIO
104
+
105
+ from fastapi import UploadFile
106
+ from starlette.datastructures import UploadFile as StarletteUploadFile
107
+
108
+ from ..conversation_sessions import CodeInterpreterSession
109
+ from ..schemas import (
110
+ Message,
111
+ RoleType
112
+ )
113
+ from werkzeug.datastructures import FileStorage
114
+
115
+ from ..exceptions.exceptions import InputErrorException, DependencyException, InternalErrorException, \
116
+ ModelMaxIterationsException
117
+
118
+ from ..utils import get_logger, upload_files
119
+
120
+ logger = get_logger()
121
+
122
+
123
async def predict(
        prompt: str,
        model_name: str,
        uploaded_files: Any,
        **kwargs: Dict[str, Any]):
    """Run one code-interpreter chat turn and return the agent's reply text.

    NOTE(review): this is a duplicated, later definition that shadows the
    ``predict`` above (different signature: no ``config_path``); consider
    removing one of the two.

    Args:
        prompt: User instruction forwarded to the agent.
        model_name: Name of the LLM backend for the session.
        uploaded_files: Raw uploads; normalized by ``upload_files`` into a
            file path or a list of paths / uploaded-file objects.
        **kwargs: Extra options forwarded to ``CodeInterpreterSession.create``.

    Returns:
        The concatenated agent response text (``None`` if the agent yields
        no responses).

    Raises:
        InputErrorException: If an uploaded file has an unsupported type.
        Exception: With a user-facing message when the chat itself fails.
    """
    start_time = time.time()

    # Create a new session for this request.
    session = await CodeInterpreterSession.create(
        model_name=model_name,
        **kwargs
    )

    files = upload_files(uploaded_files, session.session_id)
    logger.info(f"Session Creation Latency: {time.time() - start_time}")

    # Upload a single file path, or a list of paths / uploaded-file objects.
    if isinstance(files, str):
        logger.info(f"Upload {files} as file path")
        await session.upload_to_sandbox(files)
    elif isinstance(files, list):
        for file in files:
            if isinstance(file, str):
                await session.upload_to_sandbox(file)
            elif isinstance(file, (UploadFile, StarletteUploadFile)):
                # Wrap the in-memory upload as a FileStorage so the sandbox
                # client can stream it.
                file_content = file.file.read()  # get file content
                file_storage = FileStorage(
                    stream=BytesIO(file_content),
                    filename=file.filename,
                    content_type=file.content_type,
                )
                await session.upload_to_sandbox(file_storage)
            else:
                raise InputErrorException(
                    "The file type {} not supported, can't be uploaded".format(type(file)))

    # Chat with the agent; release the sandbox kernel in ``finally`` (this
    # variant previously never released it at all).
    try:
        logger.info(f"Instruction message: {prompt}")
        content = None
        output_files = []
        user_messages = [Message(RoleType.User, prompt)]
        async for response in session.chat(user_messages):
            logger.info(f'Session Chat Response: {response}')
            if content is None:
                content = response.output_text
            else:
                content += response.output_text

            output_files.extend(output_file.__dict__() for output_file in response.output_files)

        session.messages.append(Message(RoleType.Agent, content))
        logger.info(f"Total Latency: {time.time() - start_time}")

        return content
    except Exception as e:
        # Map known failure modes to user-facing messages. isinstance-based
        # matching also covers subclasses; the old exact ``type(e)`` dict
        # lookup silently fell through for them.
        known_errors = (
            (ModelMaxIterationsException, "Sorry. The agent didn't find the correct answer after multiple trials, "
                                          "Please try another question."),
            (DependencyException, "Agent failed to process message due to dependency issue. You can try it later. "
                                  "If it still happens, please contact oncall."),
            (InputErrorException, "Agent failed to process message due to value issue. If you believe all input are "
                                  "correct, please contact oncall."),
            (InternalErrorException, "Agent failed to process message due to internal error, please contact oncall."),
        )
        err_msg = next(
            (msg for exc_type, msg in known_errors if isinstance(e, exc_type)),
            "Agent failed to process message due to unknown error, please contact oncall.",
        )
        logger.error(err_msg, exc_info=True)

        # Chain the original exception so the root cause stays in tracebacks.
        raise Exception(err_msg) from e
    finally:
        AsyncPythonSandBoxTool.kill_kernels(session.session_id)
        logger.info(f"Release python sandbox {session.session_id}")