Spaces:
Sleeping
Sleeping
Commit ·
c674913
0
Parent(s):
chore(vectorstore): create a vectorstore on the Amazon products dataset.
Browse files
feat(retriever): Implement a retriever with hybrid search and reranker
- .gitignore +15 -0
- .python-version +1 -0
- README.md +0 -0
- bm25_encoder.json +0 -0
- build_vectorstore.py +209 -0
- main.py +9 -0
- pyproject.toml +19 -0
- retriever.py +286 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
.env
|
| 12 |
+
|
| 13 |
+
# Jupyter notebooks
|
| 14 |
+
*.ipynb
|
| 15 |
+
.ipynb_checkpoints
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
README.md
ADDED
|
File without changes
|
bm25_encoder.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build_vectorstore.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
from pinecone.grpc import PineconeGRPC as Pinecone
|
| 5 |
+
from pinecone import ServerlessSpec
|
| 6 |
+
from pinecone_text.sparse import BM25Encoder
|
| 7 |
+
from langchain_openai import OpenAIEmbeddings
|
| 8 |
+
import uuid
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
_ = load_dotenv()
|
| 13 |
+
|
| 14 |
+
class PineconeHybridProductIndexer:
    """Builds a Pinecone hybrid (dense + sparse) index over a product catalog.

    Dense vectors come from OpenAI embeddings; sparse vectors come from a
    BM25 encoder fitted on the full product corpus and persisted to disk
    ("bm25_encoder.json") so the query-time retriever can reuse the exact
    same vocabulary and statistics.
    """

    def __init__(self, index_name: str, api_key: str, environment: str = "us-east-1"):
        """Initialize Pinecone hybrid search for products.

        Args:
            index_name: Name of the Pinecone index to create/use.
            api_key: Pinecone API key.
            environment: AWS region for the serverless index spec.
        """
        self.pc = Pinecone(api_key=api_key)
        self.environment = environment
        self.index_name = index_name

        # Dense embedding model; its output dimension must match the
        # dimension the index is created with below.
        self.dense_model = OpenAIEmbeddings(model="text-embedding-3-large")
        self.dimensions = 3072

        # Sparse (lexical) encoder; fitted later on the whole corpus in
        # prepare_documents_for_indexing().
        self.sparse_encoder = BM25Encoder()

        # Set by create_hybrid_index() / connect_to_index().
        self.index = None

    def create_hybrid_index(self):
        """Create the Pinecone hybrid index, replacing any existing one.

        Key requirement: metric='dotproduct' — Pinecone only supports
        combined dense + sparse queries on dotproduct indexes.
        """
        try:
            # Drop any pre-existing index with the same name so a rebuild
            # always starts from a clean slate.
            if self.index_name in self.pc.list_indexes().names():
                print(f"Deleting existing index: {self.index_name}")
                self.pc.delete_index(self.index_name)

            print(f"Creating index: {self.index_name}")
            self.pc.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric="dotproduct",  # Required for hybrid search
                spec=ServerlessSpec(
                    cloud="aws",
                    region=self.environment,
                ),
            )

            # Connect to the freshly created index.
            self.index = self.pc.Index(self.index_name)

        except Exception as e:
            print(f"Error creating index: {e}")
            raise

    def connect_to_index(self):
        """Connect to an existing index.

        Raises:
            ValueError: If the index does not exist yet.
        """
        if self.index_name not in self.pc.list_indexes().names():
            raise ValueError(f"Index {self.index_name} does not exist. Create it first.")

        self.index = self.pc.Index(self.index_name)
        print(f"Connected to index: {self.index_name}")

    def delete_index(self):
        """Delete the Pinecone index if it exists (no-op otherwise)."""
        try:
            existing_indexes = self.pc.list_indexes().names()
            if self.index_name in existing_indexes:
                print(f"Deleting index: {self.index_name}")
                self.pc.delete_index(self.index_name)
                self.index = None
                print(f"Index deleted: {self.index_name}")
            else:
                print(f"Index {self.index_name} does not exist; nothing to delete.")
        except Exception as e:
            print(f"Error deleting index: {e}")
            raise

    def prepare_documents_for_indexing(self, df: pd.DataFrame) -> List[Dict]:
        """Build document payloads and fit/persist the BM25 sparse encoder.

        Args:
            df: Product dataframe; expected columns include name,
                main_category, sub_category, discount_price_usd,
                actual_price_usd, ratings, no_of_ratings, image, link.

        Returns:
            A list of dicts with 'id' (random UUID), 'text' (the string
            that is embedded both densely and sparsely) and 'metadata'.
        """
        print("Preparing documents for hybrid indexing...")

        # Texts for sparse-encoder fitting; documents for later upsert.
        texts = []
        documents = []

        for _, row in df.iterrows():
            # Create rich text content for both dense and sparse encoding
            text_content = f"Product: {row['name']}. Category: {row['main_category']}. Type: {row['sub_category']}."
            texts.append(text_content)

            # Metadata stored alongside the vector; numeric fields are cast
            # explicitly so Pinecone receives plain floats/ints.
            metadata = {
                'name': row['name'],
                'main_category': row['main_category'],
                'sub_category': row['sub_category'],
                'discount_price_usd': float(row['discount_price_usd']),
                'actual_price_usd': float(row['actual_price_usd']),
                'ratings': float(row['ratings']),
                'no_of_ratings': int(row['no_of_ratings']),
                'image': row['image'],
                'link': row['link'],
            }

            documents.append({
                'id': str(uuid.uuid4()),
                'text': text_content,
                'metadata': metadata,
            })

        # Fit the sparse encoder on the whole corpus so query-time encoding
        # shares the same term statistics.
        print("Training BM25 sparse encoder...")
        self.sparse_encoder.fit(texts)

        # Persist the fitted encoder for the retriever to load.
        print("Saving BM25 sparse encoder...")
        self.sparse_encoder.dump("bm25_encoder.json")

        return documents

    def index_products(self, df: pd.DataFrame, batch_size: int = 100):
        """Embed and upsert all products into Pinecone in batches.

        Batching keeps memory bounded; each batch upsert is retried with
        exponential back-off before the whole run is aborted.

        Args:
            df: Product dataframe (see prepare_documents_for_indexing).
            batch_size: Number of products embedded/upserted per request.
        """
        # Hoisted out of the retry loop: the original re-imported time on
        # every failed attempt.
        import time

        print(f"Starting to index {len(df)} products...")

        # Prepare documents (fits BM25 across the whole corpus and builds metadata)
        documents = self.prepare_documents_for_indexing(df)

        # Embed and upsert in batches to avoid holding all vectors in memory
        total_docs = len(documents)
        total_batches = (total_docs + batch_size - 1) // batch_size  # ceil division
        max_retries = 5
        base_delay_seconds = 1.0

        with tqdm(total=total_batches, desc="Upserting batches", unit="batch") as pbar:
            for i in range(0, total_docs, batch_size):
                batch_num = i // batch_size + 1
                batch_docs = documents[i:i + batch_size]
                start_idx = i + 1
                end_idx = min(i + len(batch_docs), total_docs)
                pbar.set_postfix_str(f"batch {batch_num}/{total_batches} items {start_idx}-{end_idx}")

                # Prepare texts
                batch_texts = [doc['text'] for doc in batch_docs]

                # Create dense (OpenAI) and sparse (BM25) vectors for this batch
                dense_vectors = self.dense_model.embed_documents(batch_texts)
                sparse_vectors = self.sparse_encoder.encode_documents(batch_texts)

                # Build Pinecone vector payloads
                batch_vectors = []
                for j, doc in enumerate(batch_docs):
                    batch_vectors.append({
                        'id': doc['id'],
                        'values': dense_vectors[j],
                        'sparse_values': {
                            'indices': sparse_vectors[j]['indices'],
                            'values': sparse_vectors[j]['values']
                        },
                        'metadata': doc['metadata'],
                    })

                # Upsert with exponential back-off retries.
                last_error = None
                for attempt in range(1, max_retries + 1):
                    try:
                        self.index.upsert(vectors=batch_vectors)
                        last_error = None
                        break
                    except Exception as e:
                        last_error = e
                        if attempt < max_retries:
                            delay = base_delay_seconds * (2 ** (attempt - 1))
                            tqdm.write(f"[Batch {batch_num}/{total_batches}] Attempt {attempt} failed: {e}. Retrying in {delay:.1f}s...")
                            time.sleep(delay)
                        else:
                            tqdm.write(f"[Batch {batch_num}/{total_batches}] Failed after {max_retries} attempts: {e}")
                if last_error is not None:
                    raise last_error

                pbar.update(1)

        print(f"Successfully indexed {total_docs} products!")
        stats = self.index.describe_index_stats()
        print(f"Index stats: {stats}")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def setup_and_run():
    """Build the product vector store end to end (run once).

    Creates the hybrid index, then loads the product CSV and embeds/upserts
    every row. Fails fast with a clear message if the Pinecone API key is
    missing from the environment.

    Raises:
        RuntimeError: If PINECONE_API_KEY is not set.
    """
    api_key = os.getenv("PINECONE_API_KEY")
    if not api_key:
        # Fail fast instead of surfacing a confusing downstream auth error.
        raise RuntimeError("PINECONE_API_KEY is not set; add it to your environment or .env file.")

    # Initialize the indexer (named for what it does, not for retrieval).
    indexer = PineconeHybridProductIndexer(
        index_name="amazon-products-catalog",
        api_key=api_key,
    )

    # Create index (do this once)
    indexer.create_hybrid_index()

    # Load and index your data (do this once)
    df = pd.read_csv("data/amazon_products.csv")
    indexer.index_products(df)

if __name__ == "__main__":
    setup_and_run()
|
main.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agno.agent import Agent
from agno.models.openai import OpenAIChat

from dotenv import load_dotenv

load_dotenv()


def main() -> None:
    """Launch an interactive CLI chat session backed by gpt-4o-mini.

    Wrapped in a function + __main__ guard so importing this module does
    not start an interactive session as a side effect.
    """
    agent = Agent(model=OpenAIChat(id="gpt-4o-mini"))
    agent.cli_app("Tell me a 5 second short story about a robot", stream=True)


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "shopping-ai-agent"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"agno>=2.0.8",
|
| 9 |
+
"cohere>=5.18.0",
|
| 10 |
+
"google-genai>=1.38.0",
|
| 11 |
+
"ipykernel>=6.30.1",
|
| 12 |
+
"langchain-community>=0.3.29",
|
| 13 |
+
"langchain-openai>=0.3.33",
|
| 14 |
+
"pandas>=2.3.2",
|
| 15 |
+
"pinecone>=7.3.0",
|
| 16 |
+
"pinecone-client[grpc]>=6.0.0",
|
| 17 |
+
"pinecone-text>=0.11.0",
|
| 18 |
+
"python-dotenv>=1.1.1",
|
| 19 |
+
]
|
retriever.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Dict, Optional, Tuple
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
import yaml
|
| 7 |
+
import time
|
| 8 |
+
from functools import wraps
|
| 9 |
+
|
| 10 |
+
from pinecone.grpc import PineconeGRPC as Pinecone
|
| 11 |
+
from pinecone_text.sparse import BM25Encoder
|
| 12 |
+
from openai import OpenAI
|
| 13 |
+
import cohere
|
| 14 |
+
|
| 15 |
+
_ = load_dotenv()
|
| 16 |
+
|
| 17 |
+
# Pydantic Models
|
| 18 |
+
class FilterModel(BaseModel):
    """Validated search filters, translated into Pinecone metadata conditions.

    Every field is optional; a field left as None is simply not applied.
    """

    # Price bounds in USD, applied to the discounted price.
    min_price: Optional[float] = Field(default=None, ge=0)
    max_price: Optional[float] = Field(default=None, ge=0)
    # Restrict results to these main categories.
    categories: Optional[List[str]] = None
    # Minimum average rating (0-5 stars).
    min_rating: Optional[float] = Field(default=None, ge=0, le=5)
    # Minimum number of reviews.
    min_reviews: Optional[int] = Field(default=None, ge=0)
|
| 25 |
+
|
| 26 |
+
class ProductItem(BaseModel):
    """A single product as returned to callers of the retriever.

    Field values are populated from the Pinecone match metadata in
    _convert_to_products.
    """

    name: str
    # Current (discounted) price and the pre-discount list price, in USD.
    price: float
    original_price: float
    # Average star rating and how many ratings it is based on.
    rating: float
    num_reviews: int
    category: str
    sub_category: str
    image_url: str
    link: str
|
| 37 |
+
|
| 38 |
+
def timer(func):
    """Decorator that prints how long each call to *func* takes.

    The wrapped function's return value is passed through unchanged;
    functools.wraps preserves its name and docstring.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        began = time.time()
        outcome = func(*args, **kwargs)
        execution_time = time.time() - began
        print(f"{func.__name__} executed in {execution_time:.3f}s")
        return outcome
    return timed
|
| 48 |
+
|
| 49 |
+
class PineconeHybridRetriever:
    """Hybrid (dense + sparse) product retriever over a Pinecone index.

    Dense embeddings come from OpenAI, sparse vectors from a pre-fitted
    BM25 encoder loaded from disk; results can optionally be reranked
    with Cohere.
    """

    def __init__(
        self,
        index_name: str,
        embedding_model: str = "text-embedding-3-large",
        embedding_dimensions: int = 3072,
        rerank_model: str = "rerank-v3.5",
        bm25_encoder_path: str = "bm25_encoder.json",
        environment: str = "us-east-1"
    ):
        """Initialize Pinecone hybrid search for products.

        Args:
            index_name: Existing Pinecone index to query.
            embedding_model: OpenAI embedding model name.
            embedding_dimensions: Dimension of the dense embeddings.
            rerank_model: Cohere rerank model name.
            bm25_encoder_path: Path to the fitted BM25 encoder JSON.
            environment: Region hint (currently informational).
        """
        self.index_name = index_name
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.rerank_model = rerank_model
        self.bm25_encoder_path = bm25_encoder_path

        self._initialize_clients(environment)

        # Initialize encoders
        self._initialize_encoders()

    def _initialize_clients(self, environment: str) -> None:
        """Initialize external service clients (Pinecone, OpenAI, Cohere).

        All three clients read their API keys from the environment.
        """
        try:
            # Initialize Pinecone and bind the target index.
            self.pc = Pinecone()
            self.index = self.pc.Index(self.index_name)

            # Initialize OpenAI (dense embeddings).
            self.openai_client = OpenAI()

            # Initialize Cohere (reranking).
            self.cohere_client = cohere.ClientV2()

        except Exception as e:
            print(f"Failed to initialize clients: {e}")
            raise

    def _initialize_encoders(self) -> None:
        """Load the pre-fitted BM25 sparse encoder from disk."""
        try:
            self.sparse_encoder = BM25Encoder().load(self.bm25_encoder_path)
        except Exception as e:
            print(f"Failed to load BM25 encoder: {e}")
            raise

    def _get_dense_embedding(self, query: str) -> List[float]:
        """Generate the dense embedding for a query string."""
        response = self.openai_client.embeddings.create(
            input=query,
            model=self.embedding_model
        )
        return response.data[0].embedding

    def _get_sparse_encoding(self, query: str) -> Dict[str, List]:
        """Generate the BM25 sparse encoding for a query string."""
        return self.sparse_encoder.encode_queries(query)

    @timer
    def _execute_parallel_encoding(self, query: str) -> Tuple[List[float], Dict[str, List]]:
        """Run dense and sparse encoding concurrently.

        The dense call is network I/O, so overlapping it with the local
        sparse encoding hides most of its latency.
        """
        with ThreadPoolExecutor(max_workers=2) as executor:
            dense_future = executor.submit(self._get_dense_embedding, query)
            sparse_future = executor.submit(self._get_sparse_encoding, query)

            dense_embedding = dense_future.result()
            sparse_encoding = sparse_future.result()

        return dense_embedding, sparse_encoding

    def _build_filter_conditions(self, filters: FilterModel) -> Dict:
        """Convert a FilterModel into Pinecone's metadata filter format.

        Uses explicit `is not None` checks so legitimate zero values
        (e.g. min_price=0 or min_rating=0) are not silently dropped,
        which the previous truthiness checks did.
        """
        conditions = {}

        # Price range on the discounted price.
        price_cond = {}
        if filters.min_price is not None:
            price_cond["$gte"] = filters.min_price
        if filters.max_price is not None:
            price_cond["$lte"] = filters.max_price
        if price_cond:
            conditions["discount_price_usd"] = price_cond

        # Ratings, review count and categories.
        if filters.min_rating is not None:
            conditions["ratings"] = {"$gte": filters.min_rating}

        if filters.min_reviews is not None:
            conditions["no_of_ratings"] = {"$gte": filters.min_reviews}

        if filters.categories:
            conditions["main_category"] = {"$in": filters.categories}

        return conditions

    def _convert_to_products(self, matches: List[Dict]) -> List[ProductItem]:
        """Convert Pinecone matches to ProductItem objects.

        Matches with missing metadata fields are skipped with a warning
        rather than aborting the whole result set.
        """
        products = []
        for match in matches:
            metadata = match.get('metadata', {})
            try:
                product = ProductItem(
                    name=metadata['name'],
                    price=metadata['discount_price_usd'],
                    original_price=metadata['actual_price_usd'],
                    rating=metadata['ratings'],
                    num_reviews=metadata['no_of_ratings'],
                    category=metadata['main_category'],
                    sub_category=metadata['sub_category'],
                    image_url=metadata['image'],
                    link=metadata['link']
                )
                products.append(product)
            except KeyError as e:
                print(f"Missing metadata field: {e}")
                continue

        return products

    @timer
    def _rerank_products(
        self,
        query: str,
        products: List[ProductItem],
        top_n: int
    ) -> List[ProductItem]:
        """Rerank products with the Cohere reranker, keeping the top_n best."""
        if not products:
            return products

        # Serialize each product via its dict form: yaml.dump cannot
        # represent pydantic model instances directly (RepresenterError).
        yaml_docs = [yaml.dump(product.model_dump(), sort_keys=False) for product in products]

        # Rerank products
        response = self.cohere_client.rerank(
            model=self.rerank_model,
            query=query,
            top_n=top_n,
            documents=yaml_docs
        )

        # Map reranked indices back onto the original product objects.
        return [products[result.index] for result in response.results]

    @timer
    def search_products(
        self,
        query: str,
        filters: FilterModel = None,
        limit: int = 10,
        alpha: float = 0.5,  # Balance between dense (1.0) and sparse (0.0)
        use_hybrid_search: bool = True,
        enable_reranking: bool = False,
    ) -> List[ProductItem]:
        """Perform hybrid search for products.

        Args:
            query: Free-text search query.
            filters: Optional metadata filters.
            limit: Number of products to return.
            alpha: Dense/sparse weighting (1.0 = dense only).
            use_hybrid_search: If False, fall back to dense-only search.
            enable_reranking: If True, over-fetch and rerank with Cohere.

        Returns:
            Up to `limit` products; an empty list on error.
        """
        try:
            if use_hybrid_search:
                dense_embedding, sparse_encoding = self._execute_parallel_encoding(query)
            else:
                dense_embedding = self._get_dense_embedding(query)
                sparse_encoding = None
                alpha = 1.0  # Force dense-only search

            # Build filters
            filter_conditions = None
            if filters:
                filter_conditions = self._build_filter_conditions(filters)

            # Over-fetch 3x when reranking so the reranker has a candidate
            # pool to choose from; otherwise fetch exactly what was asked.
            # (Kept as a separate variable so `limit` keeps its meaning.)
            fetch_limit = limit * 3 if enable_reranking else limit

            # Prepare query arguments
            query_args = {
                "vector": dense_embedding,
                "top_k": fetch_limit,
                "include_metadata": True,
                "filter": filter_conditions,
                "alpha": alpha
            }

            if use_hybrid_search and sparse_encoding:
                query_args["sparse_vector"] = sparse_encoding

            # Perform search
            results = self.index.query(**query_args)

            # Convert results to ProductItem objects
            products = self._convert_to_products(results['matches'])

            # Rerank the over-fetched candidates down to the requested limit.
            if enable_reranking and products:
                products = self._rerank_products(query, products, top_n=limit)

            return products

        except Exception as e:
            print(f"Error during search: {e}")
            return []
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Usage Example
|
| 259 |
+
# Usage Example
def example_usage():
    """Demonstrate an end-to-end hybrid search with filters and reranking."""

    # Point the retriever at the pre-built product index.
    retriever = PineconeHybridRetriever(
        index_name="amazon-products-catalog"
    )

    # Constrain results: $10-$20 and at least a 4-star average.
    filters = FilterModel(
        min_price=10,
        max_price=20,
        min_rating=4.0
    )

    # Run a hybrid (dense + sparse) search and rerank the candidates.
    products: list[ProductItem] = retriever.search_products(
        query="Black men shirts for casual wear",
        filters=filters,
        limit=10,
        use_hybrid_search=True,
        enable_reranking=True
    )

    # Print a one-line summary per result.
    for i, product in enumerate(products, 1):
        print(f"{i}. {product.name} - ${round(product.price, 2)} ({round(product.rating, 1)}⭐)")

if __name__ == "__main__":
    example_usage()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|