import json
import pickle
import requests
from pathlib import Path
from typing import List
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from app.core.config import settings
# -----------------------------
# Paths
# -----------------------------
# Every data file is resolved relative to the directory holding this module,
# so the paths do not depend on the process's working directory.
BASE_DIR = Path(__file__).resolve().parents[0]
# JSON catalog of documents (entries with "page_content" and "metadata").
DATA_PATH = BASE_DIR.joinpath("langchain_formatted.json")
# On-disk cache of the fitted BM25 sparse encoder.
BM25_PKL_PATH = BASE_DIR.joinpath("bm25.pkl")
# General Remote Embeddings
# Embedding is delegated to a remote HTTP service (original note: this
# avoids cold starts).
class GeneralRemoteEmbeddings(Embeddings):
    """Embeddings implementation backed by a remote HTTP embedding service.

    The service is expected to expose two POST routes under ``endpoint``:
    ``/embed_docs`` (body ``{"texts": [...]}``, response key ``"embeddings"``)
    and ``/embed_query`` (body ``{"text": ...}``, response key ``"embedding"``).

    Args:
        endpoint: Base URL of the service, without a trailing slash.
        timeout: Seconds to wait for the service before giving up. It is
            forwarded unchanged to ``requests.post``; without it a stalled
            endpoint would block the caller indefinitely.
    """

    def __init__(self, endpoint: str, timeout: float = 60.0):
        self.endpoint = endpoint
        self.timeout = timeout

    def _post(self, route: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``route`` and return the decoded body.

        Raises:
            requests.HTTPError: If the service answers with an error status.
            requests.Timeout: If no answer arrives within ``self.timeout``.
        """
        response = requests.post(
            f"{self.endpoint}/{route}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``."""
        # Nothing to embed: skip the network round trip entirely.
        if not texts:
            return []
        return self._post("embed_docs", {"texts": texts})["embeddings"]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        return self._post("embed_query", {"text": text})["embedding"]
# Shared dense-embedding client; it is handed to the hybrid retriever below.
embeddings = GeneralRemoteEmbeddings(endpoint="https://gaykar-generalembeddings.hf.space")
# -----------------------------
# Load Documents
# -----------------------------
def load_documents(data_path: Path) -> List[Document]:
    """Read the JSON catalog at ``data_path`` into LangChain documents.

    The file must contain an iterable of objects, each with a
    ``"page_content"`` and a ``"metadata"`` key.

    Args:
        data_path: Location of the JSON catalog file.

    Returns:
        One ``Document`` per catalog entry, in file order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If an entry lacks ``"page_content"`` or ``"metadata"``.
    """
    if not data_path.exists():
        raise FileNotFoundError(f"Catalog file not found: {data_path}")

    with data_path.open("r", encoding="utf-8") as catalog_file:
        records = json.load(catalog_file)

    documents: List[Document] = []
    for record in records:
        documents.append(
            Document(
                page_content=record["page_content"],
                metadata=record["metadata"],
            )
        )

    print(f"Loaded {len(documents)} course documents")
    return documents
# Load the catalog once at import time; everything below depends on it.
documents: List[Document] = load_documents(DATA_PATH)
if not documents:
    # Report the path that was actually read so the message cannot drift
    # out of sync with DATA_PATH.
    raise ValueError(f"No documents loaded from {DATA_PATH}")
# -----------------------------
# Pinecone Index
# -----------------------------
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
INDEX_NAME = "final-catalog-index"

# Create the index only when the account does not already have one with
# this name; otherwise the existing index is reused as-is.
_index_names = pc.list_indexes().names()
if INDEX_NAME not in _index_names:
    pc.create_index(
        name=INDEX_NAME,
        # NOTE(review): presumably the output size of the remote embedding
        # model -- confirm against the embedding service.
        dimension=384,
        # NOTE(review): Pinecone hybrid (sparse + dense) search is documented
        # to need the dotproduct metric -- confirm before changing.
        metric="dotproduct",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    print(f"Index created: {INDEX_NAME}")

index = pc.Index(INDEX_NAME)
print("Index ready:", index.describe_index_stats())
# -----------------------------
# BM25 Sparse Encoder
# Reuses the pickled encoder when one is on disk; otherwise fits a new
# encoder on the catalog and caches it for the next start-up.
# -----------------------------
bm25_encoder = BM25Encoder()
if not BM25_PKL_PATH.exists():
    print("Fitting BM25 on course catalog...")
    bm25_encoder.fit([doc.page_content for doc in documents])
    with BM25_PKL_PATH.open("wb") as pkl_out:
        pickle.dump(bm25_encoder, pkl_out)
    print(f"BM25 fitted and saved to {BM25_PKL_PATH}")
else:
    print("Loading existing BM25 model from pickle...")
    # NOTE(review): pickle.load can execute arbitrary code. This is only safe
    # while bm25.pkl is written exclusively by this module -- never load a
    # pickle obtained from an untrusted source.
    # NOTE(review): a cached encoder is not refitted when the catalog
    # changes; delete bm25.pkl to force a refit.
    with BM25_PKL_PATH.open("rb") as pkl_in:
        bm25_encoder = pickle.load(pkl_in)
# -----------------------------
# Hybrid Retriever
# -----------------------------
# Combines the remote dense embeddings with the BM25 sparse encoder and
# queries them against the Pinecone index.
retriever = PineconeHybridSearchRetriever(
    index=index,
    embeddings=embeddings,
    sparse_encoder=bm25_encoder,
)
print("Retriever ready.")