Beracles's picture
add calling openai assistant
f15d3a0
from fastapi.responses import JSONResponse
from fastapi import APIRouter, status
from openai import OpenAI
import os
# Credentials and the target assistant are read from the process environment;
# each value is None when its variable is not set.
API_KEY = os.getenv("openai_api_key")
ASSISTANT_ID = os.getenv("openai_assistant_id")

# Module-wide OpenAI client, plus the router that this module's endpoint is
# registered on.
client = OpenAI(api_key=API_KEY)
router = APIRouter()
def _failure_response(status_code, code, message):
    """Build the JSON error payload (``code`` / ``message``) for a failed call."""
    return JSONResponse(
        status_code=status_code,
        content={"code": code, "message": message},
    )


def _ask_assistant(prompt):
    """Send ``prompt`` to the configured assistant and build the HTTP response.

    This helper is synchronous because the module-level OpenAI client is the
    blocking one; ``call_openai`` runs it in a worker thread.

    Args:
        prompt: User text that becomes the single message of a new thread.

    Returns:
        A ``JSONResponse``: 200 with the answer and its citations, 408 when
        the run produced no messages, or 502 when the newest message contains
        no text block.
    """
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
    )

    # create_and_poll only returns once the run has reached a terminal state,
    # so the messages need to be listed exactly once afterwards. Re-checking in
    # a loop cannot produce new results and previously could spin forever (or
    # fail on `completed_at` being None for a run that never completed).
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=ASSISTANT_ID,
    )
    messages = list(
        client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    )

    if not messages:
        # Failure details live on `run.last_error`; the run object itself has
        # no `code` / `message` attributes. Fall back to the run status when
        # the API reported no explicit error (e.g. cancelled or expired runs).
        last_error = getattr(run, "last_error", None)
        if last_error:
            code, message = last_error.code, last_error.message
        else:
            code = run.status
            message = f"run ended with status '{run.status}' and produced no messages"
        return _failure_response(status.HTTP_408_REQUEST_TIMEOUT, code, message)

    print("received {} messages".format(len(messages)))

    # The first content block is not guaranteed to be text (it can be an
    # image), so pick the first text block of the first listed message.
    text = next(
        (
            block.text
            for block in messages[0].content
            if getattr(block, "type", None) == "text"
        ),
        None,
    )
    if text is None:
        return _failure_response(
            status.HTTP_502_BAD_GATEWAY,
            "no_text_content",
            "the assistant reply did not contain any text content",
        )

    # Work on a local copy of the text instead of mutating the SDK object.
    answer = text.value
    citations = []
    for index, annotation in enumerate(text.annotations or []):
        # Swap the raw annotation marker for a compact "[n]" reference.
        answer = answer.replace(annotation.text, f"[{index}]")
        file_citation = getattr(annotation, "file_citation", None)
        if file_citation:
            cited_file = client.files.retrieve(file_citation.file_id)
            citations.append(f"[{index}] {cited_file.filename}")

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "assistant_id": ASSISTANT_ID,
            "thread_id": thread.id,
            "run_id": run.id,
            "answer": answer,
            "citations": citations,
        },
    )


@router.get("/openai")
async def call_openai(prompt: str):
    """Answer ``prompt`` with the configured OpenAI assistant.

    Creates a new thread containing the prompt, runs the assistant on it and
    returns the first listed message with annotation markers replaced by
    ``[n]`` references.

    Args:
        prompt: The user's question, taken from the ``prompt`` query parameter.

    Returns:
        ``JSONResponse`` with ``assistant_id``, ``thread_id``, ``run_id``,
        ``answer`` and ``citations`` on success (200); a ``code`` / ``message``
        payload with status 408 when the run yields no messages, or 502 when
        the reply has no text content.
    """
    # Imported here so the block carries its own dependency.
    from fastapi.concurrency import run_in_threadpool

    # The OpenAI client used here is synchronous; running it directly inside
    # this coroutine would block the event loop for the whole assistant run.
    return await run_in_threadpool(_ask_assistant, prompt)