import os
import chromadb
from dotenv import load_dotenv
from uuid import uuid4
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain.chat_models import init_chat_model
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
import uvicorn

# ----------------------
# Configuration and Setup
# ----------------------

# Load environment variables from .env file
load_dotenv()
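
# A minimal .env for this app might look like the following (sketch; the only
# variable this code reads is OPENAI_API_KEY):
#   OPENAI_API_KEY=sk-...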

# Directories for file uploads and persistent storage of the Chroma vector database
UPLOAD_DIR = "uploads"
CHROMA_DIR = "chroma_db"

# Model versions for the LLM and embeddings
LLM = "gpt-4o-mini-2024-07-18"
EMBEDDING_MODEL = "text-embedding-3-small"

# Ensure the necessary directories exist
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(CHROMA_DIR, exist_ok=True)

# Fail fast if the OpenAI API key is missing. load_dotenv() already exports it
# to the environment, so re-assigning os.environ here is unnecessary (and would
# crash with a TypeError when the key is absent).
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your environment or .env file.")

# Initialize a persistent client for Chroma, specifying where the data is stored
client = chromadb.PersistentClient(path=CHROMA_DIR)

# FastAPI application setup
app = FastAPI()

# Enable CORS (Cross-Origin Resource Sharing) for all origins, methods, and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
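
# Note: allowing every origin is convenient for development, but in production
# you would typically pin the frontend's origin instead (hypothetical URL):
#   allow_origins=["https://your-frontend.example.com"]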

# ----------------------
# LangChain Initialization
# ----------------------

# Initialize the embedding model using OpenAI's API
embedding = OpenAIEmbeddings(model=EMBEDDING_MODEL)
# Initialize the language model; temperature=0 keeps answers deterministic and factual
llm = init_chat_model(model=LLM, model_provider="openai", temperature=0)

# Text splitter to break documents into manageable chunks for embedding and retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=50,
    separators=["\n\n", "\n", ".", " ", ""]
)
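
# Example (sketch): the splitter can be exercised on a raw string to inspect
# the chunking behavior before any PDF is involved:
#   chunks = text_splitter.split_text("First paragraph.\n\nSecond paragraph.")
#   print(len(chunks), chunks[0])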

# Set up the Chroma vector store for document embeddings and their metadata.
# The PersistentClient above already fixes the storage path, so a separate
# persist_directory argument is redundant here.
vectorstore = Chroma(
    client=client,
    embedding_function=embedding,
    collection_name="legal_docs"
)

# Define the prompt template used when querying the LLM with retrieved context
prompt_template = """
Tu es un assistant utile qui répond en français de manière claire et concise.
Réponds uniquement en utilisant le contexte fourni.
Si tu ne sais pas, dis "Je ne sais pas".
contexte : {context}
question : {question}
réponse :
"""

# Initialize the prompt template with its input variables
prompt = PromptTemplate(
    input_variables=["question", "context"],
    template=prompt_template,
)

# Format retrieved documents for the prompt, prefixing each chunk with its page number
def format_docs(docs):
    return "\n\n".join([f"(Page {d.metadata.get('page', '?')}) {d.page_content}" for d in docs])

# Set up the retriever to pull relevant documents from the vector store based on a query
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

# Define the QA chain that links together the retriever, document formatting, and LLM for querying
qa_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
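
# Example (sketch): once documents have been ingested, the chain can be queried
# directly (hypothetical question):
#   answer = qa_chain.invoke("Quelle est la durée du préavis ?")
#   print(answer)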

# ----------------------
# Document Management Functions
# ----------------------

# Add a PDF document to the vector store: load it, split it into chunks, embed the chunks
def add_pdf_to_vectorstore(file_path):
    # Load the PDF file (PyPDFLoader yields one Document per page)
    loader = PyPDFLoader(file_path)
    documents = loader.load()
    # Split the document into smaller chunks
    docs = text_splitter.split_documents(documents)
    # Generate a unique ID for each chunk
    uuids = [str(uuid4()) for _ in range(len(docs))]
    print(f"Number of chunks created: {len(docs)}")
    # Add the chunks to the vector store (Chroma embeds them on insertion)
    vectorstore.add_documents(documents=docs, ids=uuids)
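
# Example (sketch): ingesting a PDF that is already on disk, outside the HTTP
# flow (hypothetical path):
#   add_pdf_to_vectorstore("uploads/contrat.pdf")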

# ----------------------
# FastAPI Routes
# ----------------------

# Route to upload a PDF file and add its content to the vector store
@app.post("/upload")  # route path assumed
async def upload_pdf(file: UploadFile = File(...)):
    # Check that the uploaded file is a PDF (case-insensitive)
    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Seuls les fichiers PDF sont acceptés.")
    # Save the uploaded file to disk
    file_path = os.path.join(UPLOAD_DIR, file.filename)
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())
    # Add the PDF document to the vector store
    add_pdf_to_vectorstore(file_path)
    # Return a success message
    content = {"message": f"Fichier {file.filename} ajouté à la base de connaissances."}
    print(f"{content=}")
    return JSONResponse(content=content)
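
# Example (sketch): uploading a PDF from a Python client, assuming the server
# runs on localhost:8000 and the route path above:
#   import requests
#   with open("contrat.pdf", "rb") as f:
#       r = requests.post("http://localhost:8000/upload", files={"file": f})
#   print(r.json())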

# Route to interact with the assistant via a chat-like interface
@app.get("/chat")  # route path assumed
async def chat(message: str):
    # Use the QA chain to get a response; ainvoke avoids blocking the event loop
    response = await qa_chain.ainvoke(message)
    # Return the response from the assistant
    print(f"{response=}")
    return {"answer": response}

# ----------------------
# Streaming Response for Chat
# ----------------------

# Async generator that streams the chain's output to the client chunk by chunk
async def stream_chat_response(message: str):
    async for part in qa_chain.astream(message):
        # Yield each part as soon as it arrives, echoing it to the server log
        print(part, end="", flush=True)
        yield part

# FastAPI endpoint for streaming chat responses
@app.get("/chat/stream")  # route path assumed
async def chat_stream(message: str):
    """
    Endpoint to stream chat responses progressively.
    """
    # Return a StreamingResponse fed by the async generator above
    return StreamingResponse(stream_chat_response(message), media_type="text/plain")
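
# Example (sketch): consuming the streaming endpoint from Python; the encoding
# is forced to UTF-8 because the response declares no charset:
#   import requests
#   with requests.get("http://localhost:8000/chat/stream",
#                     params={"message": "Résume le document."}, stream=True) as r:
#       r.encoding = "utf-8"
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)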

# ----------------------
# Start the FastAPI app using Uvicorn
# ----------------------
if __name__ == "__main__":
    # Run the FastAPI application (for auto-reload during development, pass an
    # import string such as "main:app" together with reload=True instead)
    uvicorn.run(app, host="0.0.0.0", port=8000)