# NOTE(review): the lines here originally read "Spaces:" / "Runtime error" —
# residue from a Hugging Face Spaces status banner, not part of the program.
| import os | |
| import logging | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.chains import RetrievalQA | |
| import shutil | |
# ---------------------------------------------------------------------------
# Module-level setup: logging, FastAPI app, working directories, and the
# Google Gemini LLM/embedding clients used by the RAG pipeline below.
# All of this runs at import time; any failure aborts application startup.
# ---------------------------------------------------------------------------

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Chatbot API")

# Ensure directories exist:
#   documents/   — raw uploaded PDFs
#   vectorstore/ — persisted FAISS index
try:
    os.makedirs("documents", exist_ok=True)
    os.makedirs("vectorstore", exist_ok=True)
    logger.info("Directories 'documents' and 'vectorstore' created or already exist.")
except Exception as e:
    logger.error(f"Failed to create directories: {str(e)}")
    raise  # cannot run without writable working directories

# Check for GOOGLE_API_KEY — fail fast at startup rather than on first request.
if not os.getenv("GOOGLE_API_KEY"):
    logger.error("GOOGLE_API_KEY environment variable not set.")
    raise ValueError("GOOGLE_API_KEY environment variable not set.")

# Initialize Gemini LLM (used for answer generation in answer_query)
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=os.getenv("GOOGLE_API_KEY")
    )
    logger.info("Gemini LLM initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize Gemini LLM: {str(e)}")
    raise

# Initialize embeddings (used for both indexing and retrieval)
try:
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key=os.getenv("GOOGLE_API_KEY")
    )
    logger.info("Gemini embeddings initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize Gemini embeddings: {str(e)}")
    raise

# Path for vector store (FAISS writes index files under this prefix)
VECTOR_STORE_PATH = "vectorstore/index"
def process_pdf(pdf_path):
    """Load a PDF, split it into overlapping chunks, and index it in FAISS.

    The chunks are appended to the on-disk vector store at VECTOR_STORE_PATH
    when one already exists, otherwise a fresh store is created. The store is
    persisted back to disk either way.

    Args:
        pdf_path: Filesystem path of the PDF to ingest.

    Returns:
        A status dict confirming the document was indexed.

    Raises:
        HTTPException: 500 wrapping any loading/splitting/indexing failure.
    """
    try:
        logger.info(f"Processing PDF: {pdf_path}")
        pages = PyPDFLoader(pdf_path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(pages)

        store_exists = os.path.exists(VECTOR_STORE_PATH)
        if store_exists:
            # allow_dangerous_deserialization is required by FAISS.load_local;
            # safe here because only this app writes the store.
            store = FAISS.load_local(
                VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True
            )
            store.add_documents(chunks)
            logger.info("Added documents to existing FAISS vector store.")
        else:
            store = FAISS.from_documents(chunks, embeddings)
            logger.info("Created new FAISS vector store.")

        store.save_local(VECTOR_STORE_PATH)
        logger.info("Vector store saved successfully.")
        return {"status": "Document processed and indexed successfully"}
    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
def answer_query(query):
    """Answer *query* against the indexed documents via a RAG pipeline.

    Loads the persisted FAISS store, retrieves the top-3 most similar chunks,
    and runs a "stuff"-type RetrievalQA chain over them with the Gemini LLM.

    Args:
        query: Natural-language question to answer.

    Returns:
        Dict with the generated "answer" and up to three 200-char
        "source_documents" previews.

    Raises:
        HTTPException: 404 when no documents have been indexed yet,
            500 on any retrieval/generation failure.
    """
    if not os.path.exists(VECTOR_STORE_PATH):
        logger.warning("No vector store found. Please upload a document first.")
        # Fix: previously this returned {"error": ...}, which the endpoint
        # shipped as HTTP 200 — clients could not distinguish failure by
        # status code. Surface it as a proper 404 instead.
        raise HTTPException(
            status_code=404,
            detail="No documents indexed yet. Please upload a document first."
        )
    try:
        logger.info(f"Processing query: {query}")
        # allow_dangerous_deserialization is required by FAISS.load_local;
        # safe here because only this app writes the store.
        vector_store = FAISS.load_local(
            VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True
        )
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            return_source_documents=True
        )
        # .invoke() replaces the deprecated Chain.__call__ form
        # (available in langchain >= 0.1, which this file's imports require).
        result = qa_chain.invoke({"query": query})
        logger.info("Query processed successfully.")
        return {
            "answer": result["result"],
            # Truncate sources to short previews so responses stay small.
            "source_documents": [doc.page_content[:200] for doc in result["source_documents"]]
        }
    except Exception as e:
        logger.error(f"Error answering query: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error answering query: {str(e)}")
# NOTE(review): no route decorators were present in the original file, so the
# handlers were never registered on `app` — restoring the registration here.
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """API to upload and process a PDF document.

    Saves the upload under documents/ and indexes it into the vector store.

    Raises:
        HTTPException: 400 for non-PDF uploads, 500 for save/index failures.
    """
    # Case-insensitive check so "REPORT.PDF" is accepted too.
    if not file.filename.lower().endswith(".pdf"):
        logger.warning(f"Invalid file type uploaded: {file.filename}")
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")
    # basename() drops any directory components, so a crafted filename such as
    # "../../etc/x.pdf" cannot escape the documents/ directory (path traversal).
    safe_name = os.path.basename(file.filename)
    file_path = f"documents/{safe_name}"
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Uploaded file saved: {file_path}")
        result = process_pdf(file_path)
        return JSONResponse(content=result, status_code=200)
    except HTTPException:
        # Propagate errors that already carry a proper status code (e.g. the
        # 500 from process_pdf) instead of re-wrapping and obscuring them.
        raise
    except Exception as e:
        logger.error(f"Error in upload_document: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
# NOTE(review): no route decorator was present, so this handler was never
# registered on `app`; `query: str` binds as a query parameter (?query=...).
@app.get("/ask")
async def ask_question(query: str):
    """API to answer a query based on indexed documents."""
    logger.info(f"Received question: {query}")
    result = answer_query(query)
    return JSONResponse(content=result, status_code=200)
# NOTE(review): no route decorator was present, so this handler was never
# registered on `app` — restoring the registration here.
@app.get("/health")
async def health_check():
    """Health check endpoint; returns a static status payload."""
    logger.info("Health check requested.")
    return {"status": "API is running"}