Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import zipfile | |
| from typing import List, Optional, Any | |
| import uuid | |
| from datetime import datetime | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Depends | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| # Removed static files mounting for avatars as avatars are now served via GridFS in auth | |
| #from fastapi.staticfiles import StaticFiles | |
| from llm_initialization import get_llm | |
| from embedding import get_embeddings | |
| from document_loaders import DocumentLoader | |
| from text_splitter import TextSplitter | |
| from vector_store import VectorStoreManager | |
| from prompt_templates import PromptTemplates | |
| from chat_management import ChatManagement | |
| from retrieval_chain import RetrievalChain | |
| from urllib.parse import quote_plus | |
| from dotenv import load_dotenv | |
| from pymongo import MongoClient | |
# Load environment variables (python-dotenv reads a .env file if one is found).
load_dotenv()

# quote_plus(None) raises an opaque TypeError at import time, so fail fast with
# a message that names the missing variable instead.
_raw_mongo_password = os.getenv("MONGO_PASSWORD")
if _raw_mongo_password is None:
    raise RuntimeError("Environment variable MONGO_PASSWORD is not set.")
# URL-encoded, presumably for embedding in a MongoDB URI (it is not used
# elsewhere in this file -- confirm whether it is still needed).
MONGO_PASSWORD = quote_plus(_raw_mongo_password)
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
# The FastAPI application; the auth and transcribe routers are mounted onto it below.
app = FastAPI(title="VectorStore & Document Management API")
# Note: Since user avatars are now stored in MongoDB via GridFS and served via /auth/avatar,
# we no longer mount a local avatars directory.
# Import auth router and dependencies.
# get_current_user is the per-request dependency the chat endpoints use to obtain
# the caller's user document; users_collection is updated directly by new_chat and
# create_chain.
from auth import router as auth_router, get_current_user, users_collection
# Mount auth endpoints under /auth
app.include_router(auth_router, prefix="/auth")
# Mount the router from the transcribe module under /audio.
from transcribe import router as transcribe_router
app.include_router(transcribe_router, prefix="/audio")
# ----------------------- Shared application state -----------------------
# Every component starts out as None; startup_event() assigns the real objects.
llm = embeddings = chat_manager = None
document_loader = text_splitter = None
vector_store_manager = vector_store = None

# How many documents the retriever returns for each chat query.
k = 3
# ----------------------- Startup Event -----------------------
# NOTE(review): no `@app.on_event("startup")` / lifespan registration is visible in
# this source; confirm the hook is registered, otherwise every global stays None.
async def startup_event():
    """Create the shared components and publish them as module-level globals.

    Assigns, in order: llm, embeddings, chat_manager, document_loader,
    text_splitter, vector_store_manager and vector_store, printing a progress
    line after each stage.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter
    global vector_store_manager, vector_store

    print("Starting up: Initializing components...")

    # Models.
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")

    # Chat history store, configured from the MONGO_* settings read at import time.
    mongo_settings = dict(
        cluster_url=MONGO_CLUSTER_URL,
        database_name=MONGO_DATABASE_NAME,
        collection_name=MONGO_COLLECTION_NAME,
    )
    chat_manager = ChatManagement(**mongo_settings)
    print("Chat management initialized.")

    # Ingestion helpers.
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")

    # Vector store: keep both the manager and a direct handle to its store.
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore
    print("Vector store initialized.")
# ----------------------- New Chat Endpoint -----------------------
def new_chat(current_user: dict = Depends(get_current_user)):
    """
    Create a new chat session under the current user's document.

    Pushes ``{"chat_id", "created_at", "messages": []}`` onto the
    ``chat_histories`` array of the user document matched by email. No template
    is stored here; create_chain sets it, and chat falls back to "quiz_solving"
    when it is absent.

    Returns:
        dict: ``{"chat_id": <new UUID4 string>}``.
    """
    # Imported locally because the module header only imports the datetime class.
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now" is
    # stored by PyMongo as the same UTC BSON date.
    from datetime import timezone

    new_chat_id = str(uuid.uuid4())
    users_collection.update_one(
        {"email": current_user["email"]},
        {
            "$push": {
                "chat_histories": {
                    "chat_id": new_chat_id,
                    "created_at": datetime.now(timezone.utc),
                    "messages": [],
                }
            }
        },
    )
    return {"chat_id": new_chat_id}
# ----------------------- Create Chain Endpoint -----------------------
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
    current_user: dict = Depends(get_current_user)
):
    """
    Store the prompt template to use for one of the current user's chat sessions.

    The template name is written to ``chat_histories.$.template`` of the
    session whose ``chat_id`` matches; the chat endpoint reads it back.

    Returns:
        dict: confirmation message plus the ``chat_id`` and ``template``.

    Raises:
        HTTPException(400): ``template`` is not one of the supported names.
        HTTPException(404): the user has no chat session with ``chat_id``.
    """
    valid_templates = (
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    )
    if template not in valid_templates:
        raise HTTPException(status_code=400, detail="Invalid template selection.")
    # Positional `$` update of the chat_histories element matched by chat_id.
    result = users_collection.update_one(
        {"email": current_user["email"], "chat_histories.chat_id": chat_id},
        {"$set": {"chat_histories.$.template": template}},
    )
    # update_one does not raise when the filter matches nothing; without this
    # check an unknown chat_id would be reported as stored successfully.
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Chat session not found.")
    return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
# ----------------------- Chat Endpoint -----------------------
def chat(
    query: str,
    chat_id: str = Query(..., description="Chat session ID"),
    current_user: dict = Depends(get_current_user)
):
    """
    Process a chat query using the retrieval chain associated with the given chat_id.

    The session is looked up in ``current_user["chat_histories"]``. Its stored
    ``template`` (``"quiz_solving"`` when absent) selects the prompt. A
    RetrievalChain over the active vector store (top ``k`` documents) then
    produces the answer.

    Args:
        query: The user's question.
        chat_id: ID of an existing chat session of the current user.
        current_user: User document supplied by ``get_current_user``.

    Returns:
        StreamingResponse: the chain's output streamed as ``text/event-stream``.

    Raises:
        HTTPException(400): unknown chat session, or its stored template is invalid.
        HTTPException(500): creating the response stream failed.
    """
    # `session` rather than `chat`, so the loop variable does not shadow this endpoint.
    chat_config = next(
        (
            session
            for session in current_user.get("chat_histories", [])
            if session.get("chat_id") == chat_id
        ),
        None,
    )
    if not chat_config:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")

    template = chat_config.get("template", "quiz_solving")
    supported_templates = (
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    )
    if template not in supported_templates:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    # Each supported template has a matching PromptTemplates.get_<template>_prompt().
    prompt = getattr(PromptTemplates, f"get_{template}_prompt")()

    # Read the store through the manager rather than the module-level
    # `vector_store` alias. load_vectorstore() swaps in a new manager, and going
    # through it guarantees retrieval uses the store that is actually loaded.
    retriever = vector_store_manager.vectorstore.as_retriever(search_kwargs={"k": k})
    retrieval_chain = RetrievalChain(llm, retriever, prompt, verbose=True)

    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        # NOTE(review): if stream_chat_response is a generator function, errors
        # raised while streaming happen after headers are sent and are not caught here.
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}") from e
    return StreamingResponse(stream_generator, media_type="text/event-stream")
# ----------------------- Remaining Endpoints -----------------------
async def _load_uploaded_document(file: UploadFile):
    """Spool an upload to a temporary file and parse it with the loader for its extension.

    The extension is the text after the last "." of the filename, lower-cased.
    pdf/csv/doc(x)/htm(l)/md(markdown) use their dedicated loaders; anything
    else (including no filename) goes to ``load_unstructured``. The temporary
    file is always removed before returning.

    Raises:
        HTTPException(400): the loader failed.
    """
    # UploadFile.filename may be None; treat that like a name with no known extension.
    ext = (file.filename or "").split(".")[-1].lower()
    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path first so the finally clause can clean up even if
            # reading the upload or writing it to disk fails.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        try:
            if ext == "pdf":
                return document_loader.load_pdf(tmp_filename)
            if ext == "csv":
                return document_loader.load_csv(tmp_filename)
            if ext in ("doc", "docx"):
                return document_loader.load_doc(tmp_filename)
            if ext in ("html", "htm"):
                return document_loader.load_text_from_html(tmp_filename)
            if ext in ("md", "markdown"):
                return document_loader.load_markdown(tmp_filename)
            return document_loader.load_unstructured(tmp_filename)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}") from e
    finally:
        if tmp_filename is not None and os.path.exists(tmp_filename):
            os.remove(tmp_filename)


async def add_document(
    file: Optional[UploadFile] = File(None),  # File parameter now is an UploadFile
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None)
):
    """
    Ingest one document source into the vector store.

    Only the first provided source is used, in this order of precedence: an
    uploaded ``file``, then ``wiki_query``, then ``wiki_url``. The loaded
    documents are split into chunks and indexed via
    ``vector_store_manager.add_documents``.

    Returns:
        dict: a summary ``message`` and the ``ids`` returned by the indexer.

    Raises:
        HTTPException(400): no source was given, or the source could not be loaded.
        HTTPException(500): splitting or indexing failed.
    """
    # NOTE(review): the loader, splitter and indexer are called synchronously inside
    # this async endpoint, so they block the event loop while they run; consider
    # offloading them (e.g. asyncio.to_thread) if they are slow.
    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")

    if file is not None:
        documents = await _load_uploaded_document(file)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}") from e
    else:
        # The guard above ensures wiki_url is set when neither file nor wiki_query is.
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}") from e

    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}") from e

    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}") from e

    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
def delete_document(ids: List[str]):
    """Remove the given IDs from the vector store via ``vector_store_manager``.

    Returns:
        dict: a message listing the deleted IDs.

    Raises:
        HTTPException(500): ``delete_documents`` raised.
        HTTPException(400): ``delete_documents`` returned a falsy result.
    """
    try:
        deleted = vector_store_manager.delete_documents(ids)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(exc)}")
    if deleted:
        return {"message": f"Deleted documents with IDs: {ids}"}
    raise HTTPException(status_code=400, detail="Failed to delete documents.")
def save_vectorstore():
    """Save the active vector store under the name "faiss_index" and return it as a download.

    ``vector_store_manager.save`` is expected to return a mapping with
    ``file_path``, ``media_type`` and ``serve_filename`` keys, which are passed
    straight to FileResponse.

    Raises:
        HTTPException(500): saving failed.
    """
    try:
        saved = vector_store_manager.save("faiss_index")
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(err)}")
    path = saved["file_path"]
    media = saved["media_type"]
    download_name = saved["serve_filename"]
    return FileResponse(path=path, media_type=media, filename=download_name)
async def load_vectorstore(file: UploadFile = File(...)):
    """
    Replace the active vector store with one uploaded by the client.

    The upload is written to a temporary file and passed to
    ``VectorStoreManager.load`` with the current embeddings. On success both the
    ``vector_store_manager`` global and the ``vector_store`` alias (which the
    chat endpoint may use) are switched to the new store. The temporary file is
    always removed.

    Returns:
        dict: ``{"message": <message returned by VectorStoreManager.load>}``.

    Raises:
        HTTPException(500): writing the upload or loading the store failed.
    """
    global vector_store_manager, vector_store
    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path before reading so the finally clause can remove the
            # file even if reading the upload or writing it fails.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    # Resolve the new store before touching either global, so that a failure here
    # leaves the previous manager/store pair intact. Previously only the manager
    # was replaced and `vector_store` kept pointing at the old index.
    # NOTE(review): assumes load() returns a manager instance on success; confirm
    # what it returns on failure.
    new_store = instance.vectorstore
    vector_store_manager = instance
    vector_store = new_store
    return {"message": message}
async def merge_vectorstore(file: UploadFile = File(...)):
    """
    Merge an uploaded vector store file into the active vector store.

    The upload is written to a temporary file and passed to
    ``vector_store_manager.merge`` together with the current embeddings. The
    temporary file is always removed afterwards.

    Returns:
        Whatever ``vector_store_manager.merge`` returns.

    Raises:
        HTTPException(500): writing the upload or merging failed.
    """
    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path before reading so the finally clause can remove the
            # file even if reading the upload or writing it fails.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    return result
async def root():
    """Return the fixed welcome payload for the API's root endpoint."""
    welcome = "Welcome to the EduLearn AI."
    return dict(message=welcome)
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on all interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)