File size: 3,453 Bytes
baea076
 
9bc7a47
baea076
 
73ee6c2
 
 
baea076
9bc7a47
baea076
73ee6c2
 
baea076
 
 
73ee6c2
baea076
16dad23
baea076
73ee6c2
 
2b38d79
9bc7a47
2b38d79
 
73ee6c2
9bc7a47
 
 
73ee6c2
9bc7a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ee6c2
baea076
9bc7a47
 
baea076
 
 
 
9bc7a47
baea076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ee6c2
baea076
 
73ee6c2
 
baea076
73ee6c2
baea076
 
 
 
 
 
 
 
 
 
2b38d79
baea076
 
73ee6c2
baea076
73ee6c2
 
 
 
 
 
 
baea076
73ee6c2
baea076
73ee6c2
 
baea076
 
 
 
 
 
73ee6c2
 
baea076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ee6c2
 
 
 
 
 
baea076
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import pickle
import requests
from pathlib import Path
from typing import List
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from app.core.config import settings


# -----------------------------
# Paths
# -----------------------------

# Directory containing this module; the data files below live alongside it.
BASE_DIR = Path(__file__).resolve().parents[0]
# JSON catalog of documents (entries with "page_content" and "metadata").
DATA_PATH = BASE_DIR.joinpath("langchain_formatted.json")
# Cache file holding the pickled, fitted BM25 encoder.
BM25_PKL_PATH = BASE_DIR.joinpath("bm25.pkl")



# General Remote Embeddings
# avoids cold starts


class GeneralRemoteEmbeddings(Embeddings):
    """Embeddings implementation backed by a remote HTTP embedding service.

    Documents are embedded via ``POST {endpoint}/embed_docs`` and queries via
    ``POST {endpoint}/embed_query``. Both calls send a JSON body and read the
    vectors back from the JSON response.

    Args:
        endpoint: Base URL of the embedding service. A trailing slash is
            tolerated.
        timeout: Seconds to wait for the service before ``requests`` raises
            a timeout error. Without a timeout a stalled service would block
            the caller indefinitely.
    """

    def __init__(self, endpoint: str, timeout: float = 60.0):
        self.endpoint = endpoint
        self.timeout = timeout

    def _post_json(self, route: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``route`` and return the decoded body.

        Raises:
            requests.HTTPError: If the service answers with a 4xx/5xx status.
            requests.Timeout: If no response arrives within ``self.timeout``.
        """
        response = requests.post(
            f"{self.endpoint.rstrip('/')}/{route}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry in ``texts``.

        Reads the ``"embeddings"`` field of the service response.
        """
        return self._post_json("embed_docs", {"texts": texts})["embeddings"]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string.

        Reads the ``"embedding"`` field of the service response.
        """
        return self._post_json("embed_query", {"text": text})["embedding"]


# Module-level embeddings client shared by the retriever; it talks to the
# hosted embedding service at the URL below.
embeddings = GeneralRemoteEmbeddings("https://gaykar-generalembeddings.hf.space")


# -----------------------------
# Load Documents
# -----------------------------

def load_documents(data_path: Path) -> List[Document]:
    """Read the JSON catalog at ``data_path`` and wrap each entry in a Document.

    Every entry in the file must provide ``"page_content"`` and
    ``"metadata"`` keys; they are passed straight through to ``Document``.

    Args:
        data_path: Location of the UTF-8 encoded JSON catalog file.

    Returns:
        The documents in file order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"Catalog file not found: {data_path}")

    records = json.loads(data_path.read_text(encoding="utf-8"))

    loaded = [
        Document(page_content=record["page_content"], metadata=record["metadata"])
        for record in records
    ]

    print(f"Loaded {len(loaded)} course documents")
    return loaded


# Load the catalog once at import time; everything below depends on it.
documents: List[Document] = load_documents(DATA_PATH)

# An empty catalog would leave BM25 with nothing to fit on, so fail fast.
# The message reports the path that was actually read.
if not documents:
    raise ValueError(f"No documents loaded from {DATA_PATH}")


# -----------------------------
# Pinecone Index
# -----------------------------

# Pinecone client authenticated with the API key from application settings.
pc = Pinecone(api_key=settings.PINECONE_API_KEY)

INDEX_NAME = "final-catalog-index"

# Create the serverless index only when it does not exist yet.
# NOTE(review): dimension=384 presumably matches the output size of the
# remote embedding model — confirm against the embedding service.
_existing_index_names = pc.list_indexes().names()
if INDEX_NAME not in _existing_index_names:
    pc.create_index(
        name=INDEX_NAME,
        dimension=384,
        metric="dotproduct",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    print(f"Index created: {INDEX_NAME}")

# Handle used for all subsequent reads/writes against the index.
index = pc.Index(INDEX_NAME)
print("Index ready:", index.describe_index_stats())


# -----------------------------
# BM25 Sparse Encoder
# Loads from pickle if exists, fits and saves if not
# -----------------------------

# Start from a fresh encoder; it is fitted below, or replaced by the cached
# instance when a pickle file is already present.
bm25_encoder = BM25Encoder()

if not BM25_PKL_PATH.exists():
    # No cache yet: fit on the catalog text and persist the result.
    print("Fitting BM25 on course catalog...")
    bm25_encoder.fit([document.page_content for document in documents])
    with BM25_PKL_PATH.open("wb") as pkl_file:
        pickle.dump(bm25_encoder, pkl_file)
    print(f"BM25 fitted and saved to {BM25_PKL_PATH}")
else:
    print("Loading existing BM25 model from pickle...")
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # this is only safe while bm25.pkl is produced by this module itself.
    with BM25_PKL_PATH.open("rb") as pkl_file:
        bm25_encoder = pickle.load(pkl_file)


# -----------------------------
# Hybrid Retriever
# -----------------------------

# Wire the Pinecone index, the BM25 sparse encoder and the remote dense
# embeddings client into a single hybrid retriever.
retriever = PineconeHybridSearchRetriever(
    index=index,
    sparse_encoder=bm25_encoder,
    embeddings=embeddings,
)

print("Retriever ready.")