# NOTE(review): the original paste carried Hugging Face Spaces UI residue here
# ("Spaces:" / "Runtime error" x2) above the code. Preserved as a comment so
# the file parses; the runtime error itself is addressed in the code below.
| import os | |
| import uvicorn | |
| import tempfile | |
| from openai import AsyncOpenAI | |
| from fastapi import FastAPI, Body, UploadFile, File, Depends, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from aimakerspace.openai_utils.prompts import ( | |
| UserRolePrompt, | |
| SystemRolePrompt, | |
| ) | |
| from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader | |
| from qdrant_client import QdrantClient | |
| from fastapi.security import APIKeyHeader | |
| import uuid | |
| from typing import Dict, Optional | |
# System prompt: constrains the model to answer only from the retrieved context.
system_template = """\
Use the following context to answer a users question. 
If you cannot find the answer in the context, say you don't know the answer.
"""
system_role_prompt = SystemRolePrompt(system_template)

# User prompt: retrieved context followed by the user's question.
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)

app = FastAPI()

# Async OpenAI client; reads OPENAI_API_KEY from the environment.
openai = AsyncOpenAI()

# In-memory Qdrant instance: one collection per session, lost on restart.
vector_db = QdrantClient(":memory:")
text_splitter = CharacterTextSplitter()

# Per-session state keyed by the X-Session-ID header value.
# NOTE(review): the "vector_db"/"vector_db_retriever" slots are initialized but
# never populated in this file — presumably vestigial; verify against callers.
sessions: Dict[str, dict] = {}

# Session ID is transported in a custom header; auto_error=False so a missing
# header falls through to new-session creation in get_session.
api_key_header = APIKeyHeader(name="X-Session-ID", auto_error=False)
async def get_session(session_id: Optional[str] = Depends(api_key_header)):
    """Resolve the caller's session from the ``X-Session-ID`` header.

    A missing header mints a brand-new session; an unrecognized ID is a 404.
    Returns a ``(session_id, session_state)`` tuple.
    """
    if session_id:
        # Caller presented an ID: it must already exist.
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        return session_id, sessions[session_id]

    # No header supplied — start a fresh session with empty state.
    session_id = str(uuid.uuid4())
    sessions[session_id] = {"vector_db": None, "vector_db_retriever": None}
    return session_id, sessions[session_id]
def process_file(file: UploadFile):
    """Persist an upload to a temp path, load it, and split it into chunks.

    ``.pdf`` uploads are parsed with PDFLoader; everything else is treated as
    plain text via TextFileLoader.

    Returns:
        The list of text chunks produced by the module-level text splitter.

    The temporary file is always removed, even when loading or splitting fails.
    """
    print(f"Processing file: {file.filename}")

    # Keep the original extension so loaders can detect the format; fall back
    # to .txt for extensionless names (the old split('.')[-1] logic produced a
    # bogus suffix like ".README" in that case).
    suffix = os.path.splitext(file.filename)[1] or ".txt"

    # Write the upload, then CLOSE the handle before loading: reading a file
    # that is still open for writing fails on Windows and may see unflushed
    # data elsewhere. delete=False keeps the file past the `with` block.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(file.file.read())
        temp_path = temp_file.name
    print(f"Created temporary file at: {temp_path}")

    try:
        # Loader construction is inside the try so a loader error cannot leak
        # the temporary file.
        if file.filename.lower().endswith(".pdf"):
            loader = PDFLoader(temp_path)
        else:
            loader = TextFileLoader(temp_path)
        documents = loader.load_documents()
        return text_splitter.split_texts(documents)
    finally:
        # Best-effort cleanup; only OS-level failures are expected here.
        try:
            os.unlink(temp_path)
        except OSError as e:
            print(f"Error cleaning up temporary file: {e}")
async def get_response(msg: str, session_id: str, vector_db: QdrantClient):
    """Answer *msg* with RAG over the session's Qdrant collection, streamed.

    The top 4 matching documents are concatenated into the context prompt and
    the model's reply is streamed back as server-sent events.
    """
    hits = vector_db.query(
        collection_name=session_id,
        query_text=msg,
        limit=4,
    )
    # One document per line (trailing newline preserved, as before).
    context_prompt = "".join(hit.document + "\n" for hit in hits)

    messages = [
        system_role_prompt.create_message(),
        user_role_prompt.create_message(question=msg, context=context_prompt),
    ]
    stream = await openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.0,
        stream=True,
    )

    async def token_stream():
        # Forward each non-empty delta as it arrives.
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta is not None:
                yield delta
        # Trailing empty chunk, kept from the original implementation.
        yield ""

    return StreamingResponse(token_stream(), media_type="text/event-stream")
# NOTE(review): this handler had no route decorator, so FastAPI never
# registered it — the endpoint was unreachable (a likely cause of the deployed
# "Runtime error"). The path below is an assumption — confirm against the
# front-end's fetch URL.
@app.post("/api/chat")
async def get_bot_response(
    msg: str = Body(...),
    session_data: tuple = Depends(get_session)
):
    """Stream a RAG-grounded chat reply for the caller's session."""
    session_id, _ = session_data
    print(f"Session ID: {session_id}")
    # The session ID doubles as the Qdrant collection name.
    response = await get_response(msg, session_id, vector_db)
    return response
# NOTE(review): this handler also had no route decorator and was never
# registered. The path below is an assumption — confirm against the front-end.
@app.post("/api/upload")
async def get_file_response(
    file: UploadFile = File(..., description="A text file to process"),
    session_data: tuple = Depends(get_session)
):
    """Index an uploaded file into the caller's session collection.

    Chunks the file and adds the chunks to a Qdrant collection named after the
    session ID. Returns the session ID so the client can send it on subsequent
    chat requests; processing failures come back as a 422 JSON response.
    """
    session_id, _ = session_data
    print(f"Session ID: {session_id}")
    if not file.filename:
        return {"error": "No file uploaded"}
    try:
        chunks = process_file(file)
        vector_db.add(
            collection_name=session_id,
            documents=chunks,
        )
        return {
            "message": "File processed successfully",
            "session_id": session_id
        }
    except Exception as e:
        # Surface parsing/embedding failures to the client rather than a 500.
        return JSONResponse(
            status_code=422,
            content={"detail": str(e)}
        )
# Serve the built front-end. Mounted last so API routes registered above take
# precedence; html=True serves dist/index.html at "/".
# NOTE(review): the original also had `app.get("/")(StaticFiles(...))`, which
# registered a StaticFiles INSTANCE as a route handler — invalid as an endpoint
# and redundant with this mount — so it was removed.
app.mount("/", StaticFiles(directory="dist", html=True), name="static")

if __name__ == "__main__":
    # Defaults to 127.0.0.1:8000; pass host/port here if deploying elsewhere.
    uvicorn.run("server:app")