# maharani-chatbot / qdrant_tool.py
# Shakeel401's picture
# Update qdrant_tool.py
# d89092c verified
import os
from dotenv import load_dotenv
from typing import Any, List, Dict, Optional
from agents import function_tool
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from qdrant_client import QdrantClient, models
# Load environment variables (e.g. from a local .env file) before reading them.
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# Fail fast at import time: both settings are required to reach Qdrant.
if not (QDRANT_URL and QDRANT_API_KEY):
    raise ValueError("โŒ Missing QDRANT_URL or QDRANT_API_KEY. Set them in your .env file.")
print(
    "๐Ÿ”ง Qdrant settings loaded: "
    f"URL={QDRANT_URL}, API_KEY={'SET' if QDRANT_API_KEY else 'MISSING'}"
)
# Embedding model shared by both vector stores below.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# Single Qdrant client (gRPC transport preferred), reused by every store.
client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

# One LangChain vector-store wrapper per Qdrant collection.
vector_store = QdrantVectorStore(
    embedding=embedding_model,
    collection_name="maharani",  # general Maharani content
    client=client,
)
vendor_store = QdrantVectorStore(
    embedding=embedding_model,
    collection_name="vendors",  # vendor records
    client=client,
)
# ------------------------------------------------
# Helper: Format metadata for readability
# ------------------------------------------------
def format_metadata(metadata: Optional[dict]) -> str:
    """Render document metadata as Markdown ``**Label:** value`` lines.

    Well-known keys come first, in a fixed order and with friendly labels,
    and are skipped when their value is falsy. A ``vendors`` entry is
    flattened: a dict becomes ``key: value`` pairs (falsy values dropped), a
    list/tuple becomes a comma-separated list (falsy items dropped), and any
    other value is shown as its string form. Every remaining key is emitted
    with its name capitalized, unless its value is ``None`` or an empty
    string/collection.

    Args:
        metadata: Metadata mapping of a retrieved document; ``None`` is
            treated as empty.

    Returns:
        The rendered lines joined with newlines, or ``"No metadata"`` when
        nothing was rendered.
    """
    metadata = metadata or {}

    # (metadata key, display label) for the fields that are shown first.
    known_fields = (
        ("vendor_name", "Vendor Name"),
        ("category", "Category"),
        ("website", "Website"),
        ("portfolio_url", "Portfolio"),
        ("doc_type", "Type"),
        ("source", "Source"),
    )
    parts = [
        f"**{label}:** {metadata[key]}"
        for key, label in known_fields
        if metadata.get(key)
    ]

    vendors = metadata.get("vendors")
    if vendors:
        # Accept a mapping, a sequence, or a scalar instead of assuming a dict.
        if isinstance(vendors, dict):
            vendor_str = ", ".join(f"{k}: {v}" for k, v in vendors.items() if v)
        elif isinstance(vendors, (list, tuple)):
            vendor_str = ", ".join(str(v) for v in vendors if v)
        else:
            vendor_str = str(vendors)
        if vendor_str:
            parts.append(f"**Vendors:** {vendor_str}")

    handled = {key for key, _ in known_fields} | {"vendors"}
    for key, value in metadata.items():
        if key in handled:
            continue
        # Skip empty values so the output is not padded with "None" lines;
        # falsy-but-meaningful values such as False or 0 are still shown.
        if value is None or (isinstance(value, (str, list, tuple, dict, set)) and not value):
            continue
        parts.append(f"**{str(key).capitalize()}:** {value}")

    return "\n".join(parts) if parts else "No metadata"
# ------------------------------------------------
# Tool 1: General Maharani Search
# ------------------------------------------------
@function_tool
def search_maharani(query: str, k: int = 5) -> str:
    """Search the Maharani Weddings vector database (general content chunks).

    Args:
        query: Natural-language search text.
        k: Number of results to return; values below 1 are treated as 1.

    Returns:
        A Markdown string with one "### Result N" section per hit (chunk
        content plus formatted metadata), a "No results found" notice when
        nothing matches, or "Error: ..." if the Qdrant query fails.
    """
    print(f"\n๐Ÿ” search_maharani called with query='{query}', k={k}")
    # The LLM chooses k; a non-positive result count is not a meaningful
    # request, so ask for at least one hit.
    k = max(1, k)

    try:
        results = vector_store.similarity_search(query=query, k=k)
    except Exception as e:  # tool boundary: report the failure to the agent
        print(f"โŒ Error while querying Qdrant: {e}")
        return f"Error: {e}"
    print(f"โœ… Retrieved {len(results)} results from Qdrant")

    if not results:
        return "โš ๏ธ No results found."

    sections = []
    for i, doc in enumerate(results, start=1):
        print(f"๐Ÿ“ Formatting result {i}: {doc.metadata.get('vendor_name', 'Untitled')}")
        sections.append(
            f"### Result {i}\n\n"
            f"**Content:** {doc.page_content.strip()}\n\n"
            f"**Metadata:**\n{format_metadata(doc.metadata)}\n"
        )
    return "\n\n".join(sections)
# ------------------------------------------------
# Tool 2: Vendor Search with Filters + Certified Priority
# ------------------------------------------------
# Placeholder strings an LLM may pass instead of omitting an optional filter.
_NULLISH_FILTER_VALUES = frozenset({"", "none", "null"})

# Minimum number of chunks fetched per vendor search. More than `k` chunks are
# needed because several chunks can belong to the same vendor and are
# de-duplicated afterwards.
_MIN_VENDOR_FETCH = 15


def _normalize_filter_value(value: Optional[str]) -> Optional[str]:
    """Return *value* stripped and lower-cased, or None if unset or a null-like placeholder."""
    if value is None:
        return None
    cleaned = str(value).strip().lower()
    return None if cleaned in _NULLISH_FILTER_VALUES else cleaned


def _build_vendor_filter(
    category: Optional[str],
    location: Optional[str],
    certified: Optional[str],
):
    """Build a Qdrant ``must`` (AND) filter over the vendor metadata payload, or None if no filter applies."""
    conditions = []
    for field, raw in (("location", location), ("category", category), ("certified", certified)):
        value = _normalize_filter_value(raw)
        if value is not None:
            conditions.append(
                models.FieldCondition(
                    key=f"metadata.{field}",
                    match=models.MatchValue(value=value),
                )
            )
    return models.Filter(must=conditions) if conditions else None


def _search_vendor_store(query: str, fetch_k: int, active_filter) -> list:
    """Run a similarity search on the vendors collection; log and return [] on failure."""
    try:
        return vendor_store.similarity_search(
            query=query,
            k=fetch_k,
            filter=active_filter,
        )
    except Exception as e:  # tool boundary: degrade to "no results" (caller may retry unfiltered)
        print(f"โŒ Error while querying Qdrant vendors: {e}")
        return []


def _is_certified(value: Any) -> bool:
    """Return True when a stored ``certified`` value reads as "true" (case-insensitive)."""
    return str(value).lower() == "true"


def _vendor_summary(doc) -> Dict[str, Any]:
    """Project a vendor document onto the fields returned to the agent."""
    meta = doc.metadata
    return {
        "vendor_name": meta.get("vendor_name"),
        "category": meta.get("category"),
        "location": meta.get("location"),
        "certified": meta.get("certified"),
        "website": meta.get("website"),
        "email": meta.get("email"),
        "phone": meta.get("phone"),
        "portfolio_url": meta.get("portfolio_url"),
        "content": doc.page_content[:200],
    }


@function_tool
def search_vendors(
    query: str,
    k: int = 5,
    category: Optional[str] = None,
    location: Optional[str] = None,
    certified: Optional[str] = None,
    prioritize_certified: bool = True
) -> List[Dict[str, Any]]:
    """Search vendors in the 'vendors' Qdrant collection with optional metadata filters.

    If the filtered search finds nothing, the search is retried once without
    any filters. Results are de-duplicated by vendor name, keeping the first
    (best-ranked) chunk for each vendor.

    Args:
        query: Free-text description of the vendor being looked for.
        k: Maximum number of vendors to return; values below 1 are treated as 1.
        category: Vendor category to filter on (matched lower-cased). Omit it,
            or pass "none"/"null", for no category filter.
        location: Vendor location to filter on (matched lower-cased). Omit it,
            or pass "none"/"null", for no location filter.
        certified: Certified flag to filter on, e.g. "true" or "false"
            (matched lower-cased). Omit it, or pass "none"/"null", for no filter.
        prioritize_certified: If true, certified vendors are listed before
            uncertified ones; if false, vendors keep their search-relevance order.

    Returns:
        Up to `k` vendor dicts with keys vendor_name, category, location,
        certified, website, email, phone, portfolio_url and content (the first
        200 characters of the chunk), or a single-element list holding a
        "message" key when no vendor matches.
    """
    print(f"\n๐Ÿ” search_vendors called with query='{query}', k={k}, "
          f"category={category}, location={location}, certified={certified}, "
          f"prioritize_certified={prioritize_certified}")

    # A negative k would slice from the end of the list; ask for at least one vendor.
    k = max(1, k)
    qdrant_filter = _build_vendor_filter(category, location, certified)
    # Over-fetch so that de-duplication still leaves up to k vendors, even for large k.
    fetch_k = max(_MIN_VENDOR_FETCH, 3 * k)

    results = _search_vendor_store(query, fetch_k, qdrant_filter)
    print(f"โœ… Retrieved {len(results)} vendor results (with filters)")
    if not results and qdrant_filter:
        print("โš ๏ธ No results found with filters โ€” retrying without filters...")
        results = _search_vendor_store(query, fetch_k, None)
        print(f"๐Ÿ” Retrieved {len(results)} vendor results (without filters)")

    vendors: List[Dict[str, Any]] = []
    seen = set()
    for doc in results:
        name = doc.metadata.get("vendor_name")
        if not name or name in seen:
            continue
        seen.add(name)
        vendors.append(_vendor_summary(doc))

    # Also covers hits that exist but carry no vendor_name.
    if not vendors:
        return [{"message": "โš ๏ธ No matching vendors found."}]

    if prioritize_certified:
        # list.sort is stable, so relevance order is kept within each group.
        vendors.sort(key=lambda v: not _is_certified(v["certified"]))

    final_results = vendors[:k]
    print(f"๐Ÿ Returning {len(final_results)} vendors (certified prioritized: {prioritize_certified})")
    return final_results