|
|
import os

from fastapi import APIRouter, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from openai import OpenAI
|
|
|
|
|
# Settings read from the process environment; each is None when unset.
# NOTE: the variable names are lowercase. Environment lookups are
# case-sensitive on POSIX, so e.g. OPENAI_API_KEY would NOT be picked up here.
ASSISTANT_ID = os.getenv("openai_assistant_id")
API_KEY = os.getenv("openai_api_key")

# One module-level SDK client, created at import time and reused by handlers.
client = OpenAI(api_key=API_KEY)

router = APIRouter()
|
|
|
|
|
|
|
|
def _ask_assistant(prompt: str) -> JSONResponse:
    """Send ``prompt`` to the configured assistant and build the HTTP response.

    This is deliberately synchronous: every OpenAI SDK call below blocks, so
    the async route hands this helper to a worker thread instead of running
    it on the event loop.

    Args:
        prompt: The user's question, sent as a single text message.

    Returns:
        200 with ``assistant_id``, ``thread_id``, ``run_id``, ``answer`` and
        ``citations`` on success. On failure, an error response whose body has
        ``code`` and ``message`` keys.
    """
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
    )

    # create_and_poll returns only once the run has reached a terminal status,
    # so no additional polling loop is needed afterwards.
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=ASSISTANT_ID,
    )

    if run.status != "completed":
        # Failure details live under ``last_error``, not on the run itself.
        # It may be None (e.g. cancelled/expired runs), so fall back to the
        # status string.
        error = run.last_error
        http_status = (
            status.HTTP_408_REQUEST_TIMEOUT
            if run.status == "expired"
            else status.HTTP_502_BAD_GATEWAY
        )
        return JSONResponse(
            status_code=http_status,
            content={
                "code": error.code if error else run.status,
                "message": (
                    error.message
                    if error
                    else f"Assistant run ended with status {run.status!r}"
                ),
            },
        )

    # Restrict to messages produced by this run (which excludes the user's
    # prompt). The API lists newest first by default, so the first text block
    # found belongs to the assistant's latest reply.
    messages = client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    text = next(
        (
            block.text
            for message in messages
            for block in message.content
            if block.type == "text"
        ),
        None,
    )

    if text is None:
        # Keep the original endpoint's 408 for "no answer came back".
        return JSONResponse(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            content={
                "code": "no_response",
                "message": "The assistant run completed without a text reply.",
            },
        )

    answer = text.value
    citations = []
    for index, annotation in enumerate(text.annotations):
        # Replace the inline annotation marker with a numbered reference that
        # matches the corresponding entry in ``citations``.
        answer = answer.replace(annotation.text, f"[{index}]")
        file_citation = getattr(annotation, "file_citation", None)
        if file_citation:
            cited_file = client.files.retrieve(file_citation.file_id)
            citations.append(f"[{index}] {cited_file.filename}")

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "assistant_id": ASSISTANT_ID,
            "thread_id": thread.id,
            "run_id": run.id,
            "answer": answer,
            "citations": citations,
        },
    )


@router.get("/openai")
async def call_openai(prompt: str):
    """Answer ``prompt`` with the configured OpenAI assistant.

    The blocking SDK work runs in FastAPI's threadpool so that a slow
    assistant run does not stall every other request on the event loop.
    """
    return await run_in_threadpool(_ask_assistant, prompt)
|
|
|