"""FastAPI RAG search service.

Loads an embedding model and a cross-encoder reranker at startup,
downloads a Chroma vector store from Google Drive if it is not present
locally, and serves a /search endpoint that retrieves top-20 candidate
documents and reranks them down to the best 3.
"""

import os
import subprocess
import time
from datetime import datetime, timedelta, timezone

from fastapi import FastAPI
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

app = FastAPI()


@app.get("/")
def greet_json():
    """Root health-check endpoint."""
    return {"Hello": "World!"}


@app.get("/test")
def greet_test_json():
    """Simple test endpoint.

    NOTE(review): this was originally also named ``greet_json``, shadowing
    the root handler (flake8 F811); renamed — the route path is unchanged.
    """
    return {"test": "test successful!"}


# Google Drive File ID (Replace with your actual ID)
GOOGLE_DRIVE_FILE_ID = "1oKaMXhc1Z9eyYODiYBiN-NVyB4eYJjbz"

# Define paths
CHROMA_PATH = "chroma_db"
ZIP_FILE = "chroma_db.zip"


def load_embedding_model(model_path: str) -> HuggingFaceEmbeddings:
    """Load a HuggingFace embedding model with normalized embeddings.

    Args:
        model_path: HuggingFace model id, e.g. "intfloat/multilingual-e5-large".

    Returns:
        A HuggingFaceEmbeddings instance, with model files cached under
        ./models and embeddings L2-normalized at encode time.
    """
    start_time = time.time()
    local_embedding = HuggingFaceEmbeddings(
        model_name=model_path,
        cache_folder="./models",
        encode_kwargs={"normalize_embeddings": True},
    )
    end_time = time.time()
    print(f'model load time {round(end_time - start_time, 0)} second')
    return local_embedding


# --- Module-level startup: load models and the vector store exactly once ---
# (The original guarded these with `if not reranker_model:` immediately after
# setting both names to None — always-true dead guards, removed.)

reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
print("reranker model loaded")

start_time = time.time()
embedding = load_embedding_model(model_path="intfloat/multilingual-e5-large")
end_time = time.time()
print(f'embedding model load time {round(end_time - start_time, 0)} second')
print("embedding model loaded")

if not os.path.exists(CHROMA_PATH):
    print("Downloading ChromaDB from Google Drive...")
    # check=True: fail fast if the download or extraction breaks, instead of
    # silently continuing and serving against a missing/empty database.
    subprocess.run(
        ["gdown", f"https://drive.google.com/uc?id={GOOGLE_DRIVE_FILE_ID}", "-O", ZIP_FILE],
        check=True,
    )
    subprocess.run(["unzip", ZIP_FILE], check=True)  # Extract database

retriever = Chroma(
    persist_directory=CHROMA_PATH,
    embedding_function=embedding,
).as_retriever(search_kwargs={"k": 20})
print("ChromaDB loaded!")


def rag_with_reranking(query: str):
    """Retrieve top-20 candidate documents and rerank them down to 3.

    Args:
        query: Free-text search query.

    Returns:
        The reranked documents as returned by the compression retriever.
    """
    compressor = CrossEncoderReranker(model=reranker_model, top_n=3)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
    results = compression_retriever.invoke(query)
    return results


@app.get("/search")
def search_text(query: str):
    """Searches for similar texts."""
    # Request timestamp logged in UTC+7 — presumably the operators' local
    # timezone; confirm before changing.
    now_utc = datetime.now(timezone.utc)
    now = now_utc + timedelta(hours=7)
    print(now.strftime('%Y-%m-%d %H:%M:%S +07'))
    print(f"Searching for: {query}")

    start = time.time()
    results = rag_with_reranking(query)
    end = time.time()
    print("Search Time:", end - start, "seconds")

    return {"results": [doc.page_content for doc in results]}