Spaces:
Paused
Paused
| import os, subprocess, time | |
| from datetime import datetime, timedelta, timezone | |
| from fastapi import FastAPI | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| from langchain.retrievers.document_compressors import CrossEncoderReranker | |
app = FastAPI()


# NOTE(review): route paths are assumptions — confirm against real clients.
# "/" matches the stock Hugging Face FastAPI Space template; "/test" is a guess.
# The root handler has its own name so it is not shadowed by greet_json below.
@app.get("/")
def greet_root():
    """Return the static greeting served at the root path."""
    return {"Hello": "World!"}


@app.get("/test")
def greet_json():
    """Return a static payload confirming the service is responding."""
    return {"test": "test successful!"}
# Google Drive file ID of the zipped Chroma database. The literal is the
# default; set the GOOGLE_DRIVE_FILE_ID environment variable to use a different
# archive without editing the code (an empty value falls back to the default).
GOOGLE_DRIVE_FILE_ID = (
    os.environ.get("GOOGLE_DRIVE_FILE_ID") or "1oKaMXhc1Z9eyYODiYBiN-NVyB4eYJjbz"
)

# Directory the Chroma database is persisted in, and the archive name it is
# downloaded to when that directory does not exist yet.
CHROMA_PATH = "chroma_db"
ZIP_FILE = "chroma_db.zip"

# Model handles; None until the start-up code loads them.
reranker_model = None
embedding = None
def load_embedding_model(model_path: str, *, cache_folder: str = "./models"):
    """Load a HuggingFace sentence-embedding model and print its load time.

    Args:
        model_path: Model id or local path, passed as ``model_name``.
        cache_folder: Directory where downloaded model files are cached.
            Defaults to ``"./models"``, the previously hard-coded location.

    Returns:
        A ``HuggingFaceEmbeddings`` instance configured with
        ``normalize_embeddings=True``.
    """
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    started = time.perf_counter()
    local_embedding = HuggingFaceEmbeddings(
        model_name=model_path,
        cache_folder=cache_folder,
        # Forwarded to the encoder: ask for normalized embedding vectors.
        encode_kwargs={"normalize_embeddings": True},
    )
    elapsed = time.perf_counter() - started
    print(f'model load time {round(elapsed, 0)} second')
    return local_embedding
# Load the cross-encoder reranker once, when the module is imported.
# An explicit None check tests "not loaded yet" rather than the model's truthiness.
if reranker_model is None:
    reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    print("reranker model loaded")
# Load the multilingual embedding model once at import time and report the
# total time of the call (load_embedding_model also prints its own timing).
if embedding is None:
    start_time = time.perf_counter()
    embedding = load_embedding_model(model_path="intfloat/multilingual-e5-large")
    end_time = time.perf_counter()
    print(f'embedding model load time {round(end_time - start_time, 0)} second')
    print("embedding model loaded")
# Fetch and unpack the prebuilt Chroma database only when it is not on disk yet.
if not os.path.exists(CHROMA_PATH):
    print("Downloading ChromaDB from Google Drive...")
    # check=True: a failed download or extraction raises CalledProcessError
    # instead of being ignored; otherwise start-up would continue and Chroma
    # would likely open a new, empty database at CHROMA_PATH.
    subprocess.run(
        ["gdown", f"https://drive.google.com/uc?id={GOOGLE_DRIVE_FILE_ID}", "-O", ZIP_FILE],
        check=True,
    )
    subprocess.run(["unzip", ZIP_FILE], check=True)  # Extract database
    if not os.path.isdir(CHROMA_PATH):
        # The archive must contain a top-level CHROMA_PATH directory.
        raise RuntimeError(f"{ZIP_FILE} did not extract a {CHROMA_PATH!r} directory")

# Retriever over the persisted store; returns up to 20 documents per query.
retriever = Chroma(
    persist_directory=CHROMA_PATH, embedding_function=embedding
).as_retriever(search_kwargs={"k": 20})
print("ChromaDB loaded!")
def rag_with_reranking(query: str, *, top_n: int = 3):
    """Retrieve documents for *query* and rerank them with the cross-encoder.

    Candidates come from the module-level ``retriever``. They are rescored
    by ``reranker_model`` through ``CrossEncoderReranker``, and only the best
    ``top_n`` are kept.

    Args:
        query: Free-text search query.
        top_n: Number of reranked documents to keep (default 3, the
            previously hard-coded value).

    Returns:
        Whatever ``ContextualCompressionRetriever.invoke`` returns —
        presumably a list of LangChain documents (callers read
        ``page_content``); confirm against the LangChain version in use.
    """
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=CrossEncoderReranker(model=reranker_model, top_n=top_n),
        base_retriever=retriever,
    )
    return compression_retriever.invoke(query)
# NOTE(review): no route decorator is visible on this function, so as shown
# FastAPI does not expose it as an endpoint — confirm whether one is intended.
def search_text(query: str):
    """Search for texts similar to *query* and return the reranked matches.

    Prints the current UTC+7 time, the query, and the search latency.

    Args:
        query: Free-text search query.

    Returns:
        ``{"results": [...]}`` holding the ``page_content`` of each document
        returned by ``rag_with_reranking``, in the order returned.
    """
    # Use a real UTC+7 zone rather than shifting a UTC datetime by hand;
    # the printed text is the same.
    utc_plus_7 = timezone(timedelta(hours=7))
    print(datetime.now(utc_plus_7).strftime('%Y-%m-%d %H:%M:%S +07'))
    print(f"Searching for: {query}")

    started = time.perf_counter()  # monotonic clock for measuring latency
    results = rag_with_reranking(query)
    print("Search Time:", time.perf_counter() - started, "seconds")

    return {"results": [doc.page_content for doc in results]}