"""FastAPI route that sends a prompt to an OpenAI Assistant and returns its answer."""

import os

from fastapi import APIRouter, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from openai import OpenAI

# Credentials and the target assistant are read from the environment once, at
# import time.
API_KEY = os.environ.get("openai_api_key")
ASSISTANT_ID = os.environ.get("openai_assistant_id")

router = APIRouter()
client = OpenAI(api_key=API_KEY)


def _first_text_block(messages):
    """Return the first text payload found in ``messages``, or ``None``.

    Message content may hold non-text blocks (for example images), so the
    first block cannot be assumed to carry a ``text`` attribute.
    """
    for message in messages:
        for block in message.content:
            if getattr(block, "type", None) == "text":
                return block.text
    return None


def _run_failure_response(run) -> JSONResponse:
    """Build the error response for a run that produced no usable answer.

    Error details live on ``run.last_error``; when the run carries none, the
    run status is reported instead so the response keeps its ``code`` and
    ``message`` keys.
    """
    last_error = getattr(run, "last_error", None)
    if last_error is not None:
        code, message = last_error.code, last_error.message
    else:
        code = run.status
        message = f"Run ended with status '{run.status}' without producing an answer."
    return JSONResponse(
        status_code=status.HTTP_408_REQUEST_TIMEOUT,
        content={"code": code, "message": message},
    )


def _apply_citations(text: str, annotations) -> tuple:
    """Replace annotation markers in ``text`` with ``[index]`` references.

    Returns the rewritten text and a list of ``"[index] filename"`` entries,
    one for each annotation that carries a file citation.
    """
    citations = []
    for index, annotation in enumerate(annotations):
        text = text.replace(annotation.text, f"[{index}]")
        file_citation = getattr(annotation, "file_citation", None)
        if file_citation:
            cited_file = client.files.retrieve(file_citation.file_id)
            citations.append(f"[{index}] {cited_file.filename}")
    return text, citations


def _ask_assistant(prompt: str) -> JSONResponse:
    """Run ``prompt`` through the assistant and build the HTTP response.

    This function is synchronous and performs blocking network calls; it is
    meant to be executed off the event loop.
    """
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    }
                ],
            }
        ]
    )

    # create_and_poll only returns once the run has stopped progressing, so the
    # messages need to be listed exactly once afterwards. Re-listing in a loop
    # without refreshing the run could never observe a different result.
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=ASSISTANT_ID,
    )
    messages = list(
        client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    )
    if not messages:
        return _run_failure_response(run)
    print(f"received {len(messages)} messages")

    message_content = _first_text_block(messages)
    if message_content is None:
        return _run_failure_response(run)

    answer, citations = _apply_citations(
        message_content.value, message_content.annotations
    )
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "assistant_id": ASSISTANT_ID,
            "thread_id": thread.id,
            "run_id": run.id,
            "answer": answer,
            "citations": citations,
        },
    )


@router.get("/openai")
async def call_openai(prompt: str):
    """Answer ``prompt`` with the configured OpenAI assistant.

    Args:
        prompt: User text sent as the only message of a new thread.

    Returns:
        A 200 ``JSONResponse`` with ``assistant_id``, ``thread_id``,
        ``run_id``, ``answer`` and ``citations``; or a 408 ``JSONResponse``
        with ``code`` and ``message`` when the run yields no text answer.

    Exceptions raised by the OpenAI client are not caught here.
    """
    # The OpenAI client used here is synchronous; running it in the threadpool
    # keeps the event loop free while the assistant run is being polled.
    return await run_in_threadpool(_ask_assistant, prompt)