"""Build a Pinecone hybrid (dense + BM25 sparse) retriever over the course catalog.

Module-level side effects (run on import):
  1. Load course documents from ``langchain_formatted.json``.
  2. Ensure the Pinecone serverless index exists.
  3. Load a cached BM25 sparse encoder, or fit and cache one.
  4. Expose a ready ``retriever`` for hybrid search.
"""

import json
import pickle
from pathlib import Path
from typing import List

import requests
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from app.core.config import settings

# -----------------------------
# Paths
# -----------------------------
BASE_DIR = Path(__file__).resolve().parent
DATA_PATH = BASE_DIR / "langchain_formatted.json"
BM25_PKL_PATH = BASE_DIR / "bm25.pkl"

# Timeout (seconds) for calls to the remote embedding service. Without one,
# a hung endpoint would block callers (and module import) indefinitely.
_EMBED_TIMEOUT = 60


# General Remote Embeddings
# avoids cold starts
class GeneralRemoteEmbeddings(Embeddings):
    """Dense embeddings served by a remote HTTP endpoint.

    Delegates embedding to an always-on service so this process avoids
    loading the model locally (no cold starts).
    """

    def __init__(self, endpoint: str):
        # Base URL of the embedding service, e.g. "https://host".
        self.endpoint = endpoint

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts via POST ``/embed_docs``.

        Raises:
            requests.HTTPError: if the service responds with an error status.
        """
        response = requests.post(
            f"{self.endpoint}/embed_docs",
            json={"texts": texts},
            timeout=_EMBED_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()["embeddings"]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via POST ``/embed_query``.

        Raises:
            requests.HTTPError: if the service responds with an error status.
        """
        response = requests.post(
            f"{self.endpoint}/embed_query",
            json={"text": text},
            timeout=_EMBED_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()["embedding"]


embeddings = GeneralRemoteEmbeddings(
    endpoint="https://gaykar-generalembeddings.hf.space"
)


# -----------------------------
# Load Documents
# -----------------------------
def load_documents(data_path: Path) -> List[Document]:
    """Load the formatted course catalog into LangChain ``Document`` objects.

    Args:
        data_path: JSON file containing a list of objects, each with
            ``page_content`` and ``metadata`` keys.

    Returns:
        One ``Document`` per catalog entry.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"Catalog file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    documents = [
        Document(
            page_content=doc["page_content"],
            metadata=doc["metadata"],
        )
        for doc in data
    ]
    print(f"Loaded {len(documents)} course documents")
    return documents


documents: List[Document] = load_documents(DATA_PATH)
if not documents:
    # BUG FIX: the old message named "formatted_catalog.json", which is not
    # the file actually read; report the real path instead.
    raise ValueError(f"No documents loaded from {DATA_PATH}")

# -----------------------------
# Pinecone Index
# -----------------------------
pc = Pinecone(api_key=settings.PINECONE_API_KEY)

INDEX_NAME = "final-catalog-index"

if INDEX_NAME not in pc.list_indexes().names():
    pc.create_index(
        name=INDEX_NAME,
        dimension=384,  # must match the remote dense-embedding vector size
        metric="dotproduct",  # required by Pinecone for hybrid (dense+sparse) queries
        spec=ServerlessSpec(
            cloud="aws",
            region="us-east-1",
        ),
    )
    print(f"Index created: {INDEX_NAME}")

index = pc.Index(INDEX_NAME)
print("Index ready:", index.describe_index_stats())

# -----------------------------
# BM25 Sparse Encoder
# Loads from pickle if exists, fits and saves if not
# -----------------------------
if BM25_PKL_PATH.exists():
    print("Loading existing BM25 model from pickle...")
    # NOTE(security): pickle.load executes arbitrary code on load. This is
    # acceptable only because bm25.pkl is a local cache written by this
    # module itself — never point this at an untrusted pickle.
    with open(BM25_PKL_PATH, "rb") as f:
        bm25_encoder = pickle.load(f)
else:
    print("Fitting BM25 on course catalog...")
    bm25_encoder = BM25Encoder()
    bm25_encoder.fit([doc.page_content for doc in documents])
    with open(BM25_PKL_PATH, "wb") as f:
        pickle.dump(bm25_encoder, f)
    print(f"BM25 fitted and saved to {BM25_PKL_PATH}")

# -----------------------------
# Hybrid Retriever
# -----------------------------
retriever = PineconeHybridSearchRetriever(
    embeddings=embeddings,
    sparse_encoder=bm25_encoder,
    index=index,
)
print("Retriever ready.")