# AdaptiveEngineService/app/utils/vectordatabase.py
# Last change by Gaykar: "changes" (commit 602f88e)
import json
import pickle
import requests
from pathlib import Path
from typing import List
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from app.core.config import settings
# -----------------------------
# Paths
# -----------------------------
# Directory containing this module; the data files below are resolved
# relative to it rather than to the process working directory.
BASE_DIR = Path(__file__).resolve().parent
# JSON catalog of documents (a list of page_content/metadata records).
DATA_PATH = BASE_DIR.joinpath("langchain_formatted.json")
# On-disk cache for the fitted BM25 sparse encoder.
BM25_PKL_PATH = BASE_DIR.joinpath("bm25.pkl")
# General Remote Embeddings
# avoids cold starts
class GeneralRemoteEmbeddings(Embeddings):
    """Embeddings client that delegates vectorisation to a remote HTTP service.

    The service is expected to expose two JSON POST routes under ``endpoint``:

    * ``/embed_docs``  -- body ``{"texts": [...]}``, reply ``{"embeddings": [...]}``
    * ``/embed_query`` -- body ``{"text": "..."}``,  reply ``{"embedding": [...]}``

    Args:
        endpoint: Base URL of the embedding service (no trailing route).
        timeout: Seconds to wait for the service before giving up. Without a
            timeout ``requests`` waits forever, so a stalled service would
            hang every caller of this class.

    Raises (from both embed methods):
        requests.HTTPError: the service answered with a non-2xx status.
        requests.Timeout: the service did not answer within ``timeout``.
        KeyError: the reply JSON lacks the expected key.
    """

    def __init__(self, endpoint: str, timeout: float = 60.0):
        self.endpoint = endpoint
        self.timeout = timeout

    def _post(self, route: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``route`` and return the decoded reply."""
        response = requests.post(
            f"{self.endpoint}/{route}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per entry of ``texts``, in the same order."""
        if not texts:
            # Nothing to embed: skip the network round trip entirely.
            return []
        return self._post("embed_docs", {"texts": texts})["embeddings"]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string."""
        return self._post("embed_query", {"text": text})["embedding"]
# Shared dense-embedding client for this module, pointed at the hosted
# embedding service.
_EMBEDDINGS_ENDPOINT = "https://gaykar-generalembeddings.hf.space"
embeddings = GeneralRemoteEmbeddings(endpoint=_EMBEDDINGS_ENDPOINT)
# -----------------------------
# Load Documents
# -----------------------------
def load_documents(data_path: Path) -> List[Document]:
    """Read the JSON catalog at ``data_path`` into LangChain ``Document`` objects.

    The file must contain a JSON array whose items each carry a
    ``"page_content"`` and a ``"metadata"`` key.

    Args:
        data_path: Location of the catalog JSON file.

    Returns:
        One ``Document`` per catalog entry, in file order.

    Raises:
        FileNotFoundError: ``data_path`` does not exist.
        KeyError: an entry is missing ``"page_content"`` or ``"metadata"``.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"Catalog file not found: {data_path}")

    raw_entries = json.loads(data_path.read_text(encoding="utf-8"))

    loaded = [
        Document(page_content=entry["page_content"], metadata=entry["metadata"])
        for entry in raw_entries
    ]

    print(f"Loaded {len(loaded)} course documents")
    return loaded
# Load the catalog eagerly at import time; an empty catalog is treated as a
# configuration error rather than silently producing an empty retriever.
documents: List[Document] = load_documents(DATA_PATH)
if not documents:
    # Report the path that was actually read (the previous message named a
    # different file, "formatted_catalog.json", which is misleading).
    raise ValueError(f"No documents loaded from {DATA_PATH}")
# -----------------------------
# Pinecone Index
# -----------------------------
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
INDEX_NAME = "final-catalog-index"

# Create the index only when it is not already present in the project.
_existing_index_names = pc.list_indexes().names()
if INDEX_NAME not in _existing_index_names:
    # NOTE(review): dimension and metric presumably have to agree with the
    # embedding service output and hybrid search -- confirm before changing.
    _serverless_spec = ServerlessSpec(cloud="aws", region="us-east-1")
    pc.create_index(
        name=INDEX_NAME,
        dimension=384,
        metric="dotproduct",
        spec=_serverless_spec,
    )
    print(f"Index created: {INDEX_NAME}")

# Handle used by the retriever; the stats call also confirms connectivity.
index = pc.Index(INDEX_NAME)
print("Index ready:", index.describe_index_stats())
# -----------------------------
# BM25 Sparse Encoder
# Fits on the catalog and caches the result as a pickle when no cache
# exists yet; otherwise reuses the cached encoder.
# -----------------------------
# A fresh encoder is always constructed first, exactly as before; it is
# replaced by the unpickled instance when a cache file is found.
bm25_encoder = BM25Encoder()
if not BM25_PKL_PATH.exists():
    print("Fitting BM25 on course catalog...")
    bm25_encoder.fit([document.page_content for document in documents])
    with BM25_PKL_PATH.open("wb") as _pkl_file:
        pickle.dump(bm25_encoder, _pkl_file)
    print(f"BM25 fitted and saved to {BM25_PKL_PATH}")
else:
    print("Loading existing BM25 model from pickle...")
    # NOTE(review): pickle.load executes arbitrary code from the file, so
    # this is only safe while bm25.pkl is written solely by this module.
    with BM25_PKL_PATH.open("rb") as _pkl_file:
        bm25_encoder = pickle.load(_pkl_file)
# -----------------------------
# Hybrid Retriever
# -----------------------------
# Combines the remote dense embeddings, the BM25 sparse encoder and the
# Pinecone index into the single retriever object exported by this module.
retriever = PineconeHybridSearchRetriever(
    index=index,
    sparse_encoder=bm25_encoder,
    embeddings=embeddings,
)
print("Retriever ready.")