Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| from typing import List | |
# Absolute directory holding this script; the project root sits two levels above it.
pwd = os.path.abspath(os.path.dirname(__file__))
# Put the project root on the import path so `project_settings` (imported below) resolves.
sys.path.append(os.path.join(pwd, "../../"))
| import requests | |
| from project_settings import environment | |
def get_args():
    """Parse the command line for this demo script.

    Returns an ``argparse.Namespace`` with a single ``api_key`` attribute.
    When ``--api_key`` is not given, the value of ``environment.get(
    "agent_x_api_key", default=None)`` from ``project_settings`` is used
    (presumably ``None`` when the setting is absent -- confirm against
    ``project_settings``).
    """
    parser = argparse.ArgumentParser()
    fallback_key = environment.get("agent_x_api_key", default=None)
    parser.add_argument("--api_key", default=fallback_key, type=str)
    parsed = parser.parse_args()
    return parsed
class AgentX(object):
    """Minimal client for the AgentX REST API under ``{url_host}/api/v1/access``.

    Constructing an instance performs a network request: the agent whose name
    equals ``agent_name`` is looked up and its id is stored in ``self.agent_id``.

    Only ``get_agent_id``, ``post_message`` and ``post_message_by_sse`` check
    the HTTP status; the other helpers return ``resp.json()`` as-is, so an
    error payload from the server is handed back to the caller unchanged.
    """

    # Class-level default so instances created without ``__init__`` still work.
    timeout = None

    def __init__(self,
                 api_key: str,
                 agent_name: str = "NXLink智能助手",
                 url_host: str = "https://api.agentx.so",
                 timeout: float = None,
                 ):
        """Store connection settings and resolve ``agent_name`` to an agent id.

        :param api_key: sent as the ``x-api-key`` header on every request.
        :param agent_name: name of the agent to look up.
        :param url_host: scheme and host of the API, without a trailing slash.
        :param timeout: forwarded to ``requests`` as ``timeout``; ``None`` (the
            default) waits indefinitely, matching the previous behaviour.
        :raises requests.HTTPError: if listing agents does not return HTTP 200.
        :raises AssertionError: if no agent named ``agent_name`` exists.
        """
        self.api_key = api_key
        self.agent_name = agent_name
        self.url_host = url_host
        self.timeout = timeout
        self.agent_id = self.get_agent_id()

    def __str__(self):
        # The API key is masked so printing/logging the client does not leak it.
        result = "<{}; agent_name: {}; agent_id: {}; api_key: {}>".format(
            self.__class__.__name__, self.agent_name, self.agent_id, self._masked_api_key())
        return result

    def _masked_api_key(self):
        """Return the API key with everything but its last 4 characters hidden."""
        key = self.api_key
        if not key:
            return key
        if len(key) <= 4:
            return "***"
        return "***" + key[-4:]

    def _headers(self, with_json_content_type: bool = False) -> dict:
        """Build the common request headers, optionally declaring a JSON body."""
        headers = {"accept": "*/*"}
        if with_json_content_type:
            headers["Content-type"] = "application/json"
        headers["x-api-key"] = self.api_key
        return headers

    def _request(self, method: str, path: str, data: dict = None,
                 json_content_type: bool = False, **kwargs):
        """Send ``method`` to ``{url_host}{path}`` and return the ``requests`` response.

        :param data: if given, serialised with ``json.dumps`` and sent as the
            body (the JSON ``Content-type`` header is then always set).
        :param json_content_type: set the JSON ``Content-type`` header even
            without a body.
        :param kwargs: extra keyword arguments for ``requests.request``
            (e.g. ``stream=True``).
        """
        url = "{}{}".format(self.url_host, path)
        body = None if data is None else json.dumps(data)
        return requests.request(
            method,
            url=url,
            headers=self._headers(json_content_type or data is not None),
            data=body,
            timeout=self.timeout,
            **kwargs
        )

    @staticmethod
    def _raise_for_status(resp):
        """Raise ``requests.HTTPError`` unless the response status is exactly 200."""
        if resp.status_code != 200:
            raise requests.HTTPError(
                "HTTP {}: {}".format(resp.status_code, resp.text), response=resp)

    @staticmethod
    def _trace_to_reference(trace):
        """Convert a trace payload into a list of ``(title, source)`` tuples.

        The literal string ``"No trace"`` (which the trace endpoints are
        handled as returning when nothing is available) is passed through
        unchanged.
        """
        if trace == "No trace":
            return "No trace"
        return [(t["title"], t["source"]) for t in trace]

    @staticmethod
    def _decode_utf8_stream(byte_chunks):
        """Yield text decoded incrementally from an iterable of UTF-8 byte chunks.

        A multi-byte character split across chunks is held back until it is
        complete. Invalid or truncated byte sequences become U+FFFD instead of
        stalling the stream. Empty pieces are never yielded.
        """
        import codecs
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for chunk in byte_chunks:
            text = decoder.decode(chunk)
            if text:
                yield text
        # Flush any incomplete trailing sequence left in the decoder.
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    def get_agent_id(self):
        """Return the ``_id`` of the agent named ``self.agent_name``.

        :raises requests.HTTPError: if the agent listing does not return HTTP 200.
        :raises AssertionError: if no agent with that name is listed.
        """
        resp = self._request("GET", "/api/v1/access/agents")
        self._raise_for_status(resp)
        result = None
        for agent in resp.json():
            if agent["name"] == self.agent_name:
                # No break: if several agents share the name, the last one wins.
                result = agent["_id"]
        if result is None:
            raise AssertionError("agent not found")
        return result

    def get_agent_config(self):
        """Return the decoded JSON describing ``self.agent_id``."""
        path = "/api/v1/access/agents/{}".format(self.agent_id)
        return self._request("GET", path).json()

    def get_conversation_list(self):
        """Return the decoded JSON listing the agent's conversations."""
        path = "/api/v1/access/agents/{}/conversations".format(self.agent_id)
        return self._request("GET", path).json()

    def post_message(self, message: str, conversation_id: str, context: int = 0):
        """Post ``message`` to a conversation and return the decoded JSON reply.

        :raises requests.HTTPError: if the response status is not 200.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/message".format(conversation_id),
            data={"message": message, "context": context},
        )
        self._raise_for_status(resp)
        return resp.json()

    def post_message_by_sse(self, message: str, conversation_id: str, context: int = 0):
        """Post ``message`` via the streaming endpoint.

        :return: ``(generator, trace_id)``. The generator yields decoded text
            pieces as they arrive and closes the HTTP response once it is
            exhausted or closed. ``trace_id`` comes from the ``x-trace-id``
            response header.
        :raises AssertionError: if the response status is not 200.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/messagesse".format(conversation_id),
            data={"message": message, "context": context},
            stream=True,
        )
        # Check the status before touching x-trace-id: error responses may lack it.
        if resp.status_code != 200:
            msg = "unexpected status {} (Content-Type: {})".format(
                resp.status_code, resp.headers.get("Content-Type"))
            resp.close()
            raise AssertionError(msg)
        try:
            trace_id = resp.headers["x-trace-id"]
        except KeyError:
            resp.close()
            raise

        def generator():
            try:
                yield from self._decode_utf8_stream(resp.iter_content())
            finally:
                # Release the streamed connection even if the caller stops early.
                resp.close()

        return generator(), trace_id

    def get_trace_by_message_id(self, message_id: str):
        """Return the decoded trace JSON for a message id."""
        path = "/api/v1/access/messages/{}/trace".format(message_id)
        return self._request("GET", path).json()

    def get_trace_by_trace_id(self, trace_id: str):
        """Return the decoded trace JSON for a trace id."""
        path = "/api/v1/access/traces/{}".format(trace_id)
        return self._request("GET", path).json()

    def post_new_conversation_id(self):
        """Create a new conversation for the agent and return its ``_id``."""
        path = "/api/v1/access/agents/{}/conversations/new".format(self.agent_id)
        js = self._request("POST", path).json()
        conversation_id = js["_id"]
        return conversation_id

    def delete_conversation(self, conversation_id: str):
        """Delete a conversation and return the decoded JSON response."""
        path = "/api/v1/access/conversations/{}".format(conversation_id)
        return self._request("DELETE", path, json_content_type=True).json()

    def update_context(self, messages: List[dict], conversation_id: str):
        """Replace a conversation's context with ``messages``; return the decoded JSON."""
        path = "/api/v1/access/conversations/{}/update-context".format(conversation_id)
        return self._request("PUT", path, data={"messages": messages}).json()

    def question_answer(self, question: str, conversation_id: str = None, context: List[dict] = None, streaming: bool = False):
        """Ask ``question`` and return ``{"answer": ..., "reference": ...}``.

        :param question: user message to send.
        :param conversation_id: conversation to post into; a new one is created
            when ``None``. NOTE: the conversation is deleted afterwards in either
            case, including one supplied by the caller.
        :param context: prior messages pushed with ``update_context`` before
            asking; when given, the message is posted with ``context=1``.
        :param streaming: use the SSE endpoint and echo chunks to stdout as they
            arrive.
        :return: dict whose ``answer`` is the reply text and whose ``reference``
            is a list of ``(title, source)`` tuples or the string ``"No trace"``.
        """
        if conversation_id is None:
            conversation_id = self.post_new_conversation_id()
        context_flag = 0 if context is None else 1
        result = {
            "answer": None,
            "reference": None
        }
        try:
            # Inside the try so the conversation is still deleted if this fails.
            if context is not None:
                self.update_context(context, conversation_id)
            if streaming:
                resp_stream, trace_id = self.post_message_by_sse(
                    question, conversation_id, context=context_flag)
                pieces = []
                for chunk in resp_stream:
                    # Flush so the streamed answer is visible as it arrives.
                    print(chunk, end="", flush=True)
                    pieces.append(chunk)
                print("\n")
                result["answer"] = "".join(pieces)
                trace = self.get_trace_by_trace_id(trace_id)
            else:
                js = self.post_message(question, conversation_id, context=context_flag)
                result["answer"] = js["text"]
                trace = self.get_trace_by_message_id(js["_id"])
            result["reference"] = self._trace_to_reference(trace)
        finally:
            self.delete_conversation(conversation_id)
        return result
def main():
    """Demo: query the "Yutong Bus" agent on top of a canned chat history.

    Prints the client, the streamed answer, the result dict and the elapsed
    wall-clock time.
    """
    agent = AgentX(api_key=get_args().api_key, agent_name="Yutong Bus")
    print(agent)

    # Earlier turns of the conversation, each a one-key dict of role -> text.
    context = [
        {"user": "你好"},
        {"assistant": "你好,我们是宇通客车公司,有什么可以帮到您的吗?"},
        {"user": "需要一辆55座客车。"},
        {"assistant": "Which country will the bus be used in?"},
        {"user": "你可以说中文吗。"},
        {"assistant": "可以的,请问您需要在哪个国家使用客车?"},
    ]
    question = "你好"

    started = time.time()
    response = agent.question_answer(question, context=context, streaming=True)
    elapsed = time.time() - started

    print(response)
    print("time cost: {}".format(elapsed))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()