File size: 2,836 Bytes
149adfe
6563a3d
c220f45
 
9000d27
 
 
 
 
 
c220f45
 
 
 
 
21633f2
 
 
9000d27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a852a21
 
 
 
163cbb2
a852a21
163cbb2
 
a852a21
 
 
 
 
 
 
 
 
 
9000d27
 
 
 
 
 
 
 
 
 
 
a852a21
6563a3d
 
 
9000d27
 
e39e8ae
9000d27
e39e8ae
 
2f45137
9000d27
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os, subprocess, time
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker

# FastAPI application instance; route handlers below register against it.
app = FastAPI()

@app.get("/")
def greet_json():
    """Root health-check endpoint; returns a static greeting payload."""
    return {"Hello": "World!"}

@app.get("/test")
def greet_test():
    """Connectivity-test endpoint; returns a static confirmation payload.

    Renamed from ``greet_json``: the original duplicated the root
    handler's name, shadowing it at module level (F811). FastAPI routes
    by path, so the HTTP interface (``GET /test``) is unchanged.
    """
    return {"test": "test successful!"}

# Google Drive File ID (Replace with your actual ID)
GOOGLE_DRIVE_FILE_ID = "1oKaMXhc1Z9eyYODiYBiN-NVyB4eYJjbz"

# Define paths
CHROMA_PATH = "chroma_db"  # directory the zipped Chroma index is extracted into
ZIP_FILE = "chroma_db.zip"  # local download target for the zipped index

# Model singletons, populated once at import time by the guards below.
reranker_model = None
embedding = None

def load_embedding_model(model_path: str):
    """Load a HuggingFace embedding model with normalized embeddings.

    Args:
        model_path: HuggingFace model id or local path to load.

    Returns:
        A ``HuggingFaceEmbeddings`` instance with weights cached under
        ``./models`` and ``normalize_embeddings=True``.
    """
    # perf_counter is monotonic — correct for measuring elapsed time,
    # unlike time.time() which can jump with wall-clock adjustments.
    start_time = time.perf_counter()
    encode_kwargs = {"normalize_embeddings": True}
    local_embedding = HuggingFaceEmbeddings(
        model_name=model_path,
        cache_folder="./models",
        encode_kwargs=encode_kwargs,
    )
    elapsed = time.perf_counter() - start_time
    print(f'model load time {round(elapsed, 0)} second')
    return local_embedding

# Import-time initialization of the reranker and embedding singletons.
# The guards are no-ops on a fresh import (both start as None) but keep
# re-execution of this module body from reloading the heavy models.
if not reranker_model:
    reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    print("reranker model loaded")
if not embedding:
    start_time = time.time()    
    embedding = load_embedding_model(model_path="intfloat/multilingual-e5-large")
    end_time = time.time()
    # NOTE(review): load_embedding_model already times and prints its own
    # load duration, so this outer timing is redundant — consider removing.
    print(f'embedding model load time {round(end_time - start_time, 0)} second')
    print("embedding model loaded")

# Download and unpack the prebuilt Chroma index on first run only.
if not os.path.exists(CHROMA_PATH):
    print("Downloading ChromaDB from Google Drive...")
    # check=True aborts startup on a failed download/extract instead of
    # silently continuing with a missing or half-written database.
    subprocess.run(["gdown", f"https://drive.google.com/uc?id={GOOGLE_DRIVE_FILE_ID}", "-O", ZIP_FILE], check=True)
    subprocess.run(["unzip", ZIP_FILE], check=True)  # Extract database

# BUG FIX: the retriever was previously created only inside the download
# branch, so any run where chroma_db/ already existed never defined
# `retriever` and /search crashed with NameError. Build it unconditionally.
retriever = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding).as_retriever(search_kwargs={"k": 20})
print("ChromaDB loaded!")


def rag_with_reranking(query : str):
    """Retrieve candidates for *query* and rerank them with the cross-encoder.

    Pulls candidates from the module-level ``retriever`` and keeps the
    top 3 after cross-encoder scoring.
    """
    reranker = CrossEncoderReranker(model=reranker_model, top_n=3)
    pipeline = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=retriever
    )
    return pipeline.invoke(query)

@app.get("/search")
def search_text(query: str):
    """Search the vector store for texts similar to *query*.

    Runs the retrieve-and-rerank pipeline and returns the page contents
    of the top reranked documents as ``{"results": [...]}``.
    """
    # Log request time in UTC+7 with a proper fixed-offset tzinfo rather
    # than adding a timedelta to a UTC timestamp (which yields a datetime
    # whose tzinfo still says UTC). Wall-clock output is identical.
    now = datetime.now(timezone(timedelta(hours=7)))
    print(now.strftime('%Y-%m-%d %H:%M:%S +07'))
    print(f"Searching for: {query}")

    start = time.time()
    results = rag_with_reranking(query)
    end = time.time()
    print("Search Time:", end - start, "seconds")

    # Return only the raw text of each reranked document.
    return {"results": [doc.page_content for doc in results]}