Beracles commited on
Commit
f15d3a0
·
1 Parent(s): fbc9217

add calling openai assistant

Browse files
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  fastapi
2
  uvicorn[standard]
3
  dashscope
 
 
1
  fastapi
2
  uvicorn[standard]
3
  dashscope
4
+ openai
src/talk_to_your_manual/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
  from fastapi import APIRouter
2
- from . import use_aliyun
3
 
4
 
5
  router = APIRouter(
6
  prefix="/talk-to-your-manual",
7
  tags=["Talk To Your Manual"],
8
  )
9
- router.include_router(use_aliyun.router)
 
 
1
  from fastapi import APIRouter
2
+ from . import use_aliyun, use_openai
3
 
4
 
5
  router = APIRouter(
6
  prefix="/talk-to-your-manual",
7
  tags=["Talk To Your Manual"],
8
  )
9
+ router.include_router(use_aliyun.router)
10
+ router.include_router(use_openai.router)
src/talk_to_your_manual/use_aliyun.py CHANGED
@@ -5,8 +5,8 @@ from fastapi.responses import JSONResponse
5
  import os
6
 
7
  router = APIRouter()
8
- API_KEY = os.environ.get("API_KEY")
9
- APP_ID = os.environ.get("APP_ID")
10
 
11
  session_id = None
12
 
@@ -49,6 +49,7 @@ async def call_aliyun(prompt: str):
49
  prompt=prompt,
50
  session_id=session_id,
51
  )
 
52
  if response.status_code == HTTPStatus.OK:
53
  last_session_id = session_id
54
  session_id = response.output.session_id
 
5
  import os
6
 
7
  router = APIRouter()
8
+ API_KEY = os.environ.get("aliyun_api_key")
9
+ APP_ID = os.environ.get("aliyun_app_id")
10
 
11
  session_id = None
12
 
 
49
  prompt=prompt,
50
  session_id=session_id,
51
  )
52
+ print(response.output.session_id)
53
  if response.status_code == HTTPStatus.OK:
54
  last_session_id = session_id
55
  session_id = response.output.session_id
src/talk_to_your_manual/use_openai.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.responses import JSONResponse
2
+ from fastapi import APIRouter, status
3
+ from openai import OpenAI
4
+ import os
5
+
6
+ API_KEY = os.environ.get("openai_api_key")
7
+ ASSISTANT_ID = os.environ.get("openai_assistant_id")
8
+ router = APIRouter()
9
+ client = OpenAI(api_key=API_KEY)
10
+
11
+
12
+ @router.get("/openai")
13
+ async def call_openai(prompt: str):
14
+ thread = client.beta.threads.create(
15
+ messages=[
16
+ {
17
+ "role": "user",
18
+ "content": [
19
+ {
20
+ "type": "text",
21
+ "text": prompt,
22
+ }
23
+ ],
24
+ }
25
+ ]
26
+ )
27
+ run = client.beta.threads.runs.create_and_poll(
28
+ thread_id=thread.id,
29
+ assistant_id=ASSISTANT_ID,
30
+ )
31
+ start_time = run.started_at
32
+ while True:
33
+ messages = list(
34
+ client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
35
+ )
36
+ if messages:
37
+ print("received {} messages".format(len(messages)))
38
+ break
39
+ end_time = run.completed_at
40
+ if end_time - start_time > 60:
41
+ break
42
+ if not messages:
43
+ return JSONResponse(
44
+ status_code=status.HTTP_408_REQUEST_TIMEOUT,
45
+ content={
46
+ "code": run.code,
47
+ "message": run.message,
48
+ },
49
+ )
50
+ message_content = messages[0].content[0].text
51
+ annotations = message_content.annotations
52
+ citations = []
53
+ for index, annotation in enumerate(annotations):
54
+ message_content.value = message_content.value.replace(
55
+ annotation.text, f"[{index}]"
56
+ )
57
+ file_citation = getattr(annotation, "file_citation", None)
58
+ if file_citation:
59
+ cited_file = client.files.retrieve(file_citation.file_id)
60
+ citations.append(f"[{index}] {cited_file.filename}")
61
+ return JSONResponse(
62
+ status_code=status.HTTP_200_OK,
63
+ content={
64
+ "assistant_id": ASSISTANT_ID,
65
+ "thread_id": thread.id,
66
+ "run_id": run.id,
67
+ "answer": message_content.value,
68
+ "citations": citations,
69
+ },
70
+ )