import os
import json
import time
from typing import List, Optional
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    MultiVectorConfig,
    MultiVectorComparator,
    HnswConfigDiff
)
# Make the sibling embedding module importable whether this file is run directly or as a package module
import sys
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from embedding import (
    get_dense_embedding,
    get_late_embedding,
    to_valid_qdrant_id
)
load_dotenv()
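# QDRANT_API_KEY is expected to be provided by the environment (for example via the .env file loaded above)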
class LongTermDatabase:
def __init__(
self,
collection_prefix: str = "longterm_db",
vector_size: int = 768,
url: str = "https://df35413f-27c8-419d-aa89-4b3901514560.us-west-1-0.aws.cloud.qdrant.io",
        api_key: Optional[str] = None
):
self.api_key = api_key or os.getenv('QDRANT_API_KEY')
if not self.api_key:
raise RuntimeError("Missing QDRANT_API_KEY environment variable.")
self.client = QdrantClient(url=url, api_key=self.api_key)
        self.collection_name = "long_rag"  # NOTE: fixed collection name; collection_prefix is currently unused
        self.vector_size = vector_size
        # Two named vectors per point: "dense" for retrieval and "late" (multi-vector) for reranking
self._ensure_collection()
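    # Typical instantiation (assumes QDRANT_API_KEY is set in the environment or a .env file):
    #   db = LongTermDatabase()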
def _ensure_collection(self):
existing = [c.name for c in self.client.get_collections().collections]
if self.collection_name not in existing:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config={
"dense": VectorParams(size=self.vector_size, distance=Distance.COSINE),
"late": VectorParams(
size=768,
distance=Distance.COSINE,
multivector_config=MultiVectorConfig(
comparator=MultiVectorComparator.MAX_SIM
),
                        # m=0 disables HNSW graph construction for the late vectors; they are only used for reranking
                        hnsw_config=HnswConfigDiff(m=0)
)
}
)
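    # Each point carries two named vectors: a single 768-d "dense" vector used for ANN retrieval, and a
    # "late" multi-vector (presumably one 768-d vector per token, ColBERT-style) compared with MaxSim and
    # used only for reranking, since HNSW indexing is disabled for it (m=0).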
def _batch_get_embeddings(self, docs: List[str]):
dense_embs = [get_dense_embedding(doc) for doc in docs]
late_embs = [get_late_embedding(doc) for doc in docs]
return list(zip(dense_embs, late_embs))
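    # Embeddings are computed one document at a time via the helpers in embedding.py; get_dense_embedding is
    # assumed to return a single vector and get_late_embedding a list of per-token vectors for the "late" field.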
def add_data(self, json_file: str, max_chunk_chars: int = 1500):
"""
Add each object in a JSON file as its own document. If the file is a list, each item is a document.
If the file is a dict, each value is a document. If a document is too large, it is split into chunks.
For objectwise JSON (list of dicts), each dict is a document.
"""
with open(json_file, 'r', encoding='utf-8') as f:
data = json.load(f)
ids, docs = [], []
# Objectwise: if data is a list of dicts, treat each dict as a document
file_prefix = os.path.splitext(os.path.basename(json_file))[0]
if isinstance(data, list) and all(isinstance(item, dict) for item in data):
for i, item in enumerate(data):
doc_json = json.dumps(item, ensure_ascii=False)
if len(doc_json) > max_chunk_chars:
n_chunks = (len(doc_json) + max_chunk_chars - 1) // max_chunk_chars
for j in range(n_chunks):
chunk = doc_json[j * max_chunk_chars : (j + 1) * max_chunk_chars]
ids.append(f"{file_prefix}_{i}_{j}")
docs.append(chunk)
else:
ids.append(f"{file_prefix}_{i}")
docs.append(doc_json)
elif isinstance(data, dict):
# If dict, treat each value as a document
for k, v in data.items():
doc_json = json.dumps(v, ensure_ascii=False)
if len(doc_json) > max_chunk_chars:
n_chunks = (len(doc_json) + max_chunk_chars - 1) // max_chunk_chars
for j in range(n_chunks):
chunk = doc_json[j * max_chunk_chars : (j + 1) * max_chunk_chars]
ids.append(f"{file_prefix}_{k}_{j}")
docs.append(chunk)
else:
ids.append(f"{file_prefix}_{k}")
docs.append(doc_json)
else:
# Fallback: treat the whole thing as one document
doc_json = json.dumps(data, ensure_ascii=False)
if len(doc_json) > max_chunk_chars:
n_chunks = (len(doc_json) + max_chunk_chars - 1) // max_chunk_chars
for j in range(n_chunks):
chunk = doc_json[j * max_chunk_chars : (j + 1) * max_chunk_chars]
ids.append(f"{file_prefix}_0_{j}")
docs.append(chunk)
else:
ids.append(f"{file_prefix}_0")
docs.append(doc_json)
print(f"Adding {len(docs)} document(s) to the database...")
emb_pairs = self._batch_get_embeddings(docs)
points = []
for i, doc_id in enumerate(ids):
dense_vec, late_vec = emb_pairs[i]
points.append(
PointStruct(
id=to_valid_qdrant_id(doc_id),
vector={"dense": dense_vec, "late": late_vec},
payload={"document": docs[i]}
)
)
# Use upsert in small batches to avoid timeouts
BATCH_SIZE = 4 # or even 2 if needed
for i in range(0, len(points), BATCH_SIZE):
batch = points[i:i+BATCH_SIZE]
self.client.upsert(
collection_name=self.collection_name,
points=batch
)
time.sleep(0.5) # short pause between batches
print(f"Indexed {len(points)} document(s).")
def smart_query(self, query_text: str, topk: int = 5, top_l: int = 5, use_late: bool = True, doc_search: bool = True) -> List[str]:
"""
Hybrid query: first prefetch with dense (topk), then rerank with late embedding (ColBERT-style) and return top_l.
If use_late is False, does dense-only search. If True, does dense prefetch + late rerank.
If doc_search is True, also filter by fuzzy/substring in the document (case-insensitive) after reranking.
"""
import re
dense_vec = get_dense_embedding(query_text)
late_vec = get_late_embedding(query_text)
# Vector search
if use_late:
from qdrant_client.models import Prefetch
results = self.client.query_points(
collection_name=self.collection_name,
prefetch=Prefetch(query=dense_vec, using="dense"),
query=late_vec,
using="late",
limit=topk,
with_payload=True
)
else:
            # Dense-only search; the vector name must be given because the collection uses named vectors
            results = self.client.search(
                collection_name=self.collection_name,
                query_vector=("dense", dense_vec),
                limit=topk,
                with_payload=True,
                with_vectors=False
            )
# Normalize Qdrant results to always be a list of ScoredPoint-like objects
points_list = None
if isinstance(results, tuple) and len(results) == 2 and isinstance(results[1], list):
points_list = results[1]
elif isinstance(results, list):
points_list = results
elif hasattr(results, 'points') and isinstance(results.points, list):
points_list = results.points
else:
points_list = []
hits = []
for hit in points_list:
# Qdrant >=1.7 returns ScoredPoint objects
if hasattr(hit, 'id') and hasattr(hit, 'payload'):
payload = hit.payload if isinstance(hit.payload, dict) else {}
hits.append({"id": hit.id, "document": payload.get('document', '')})
# Defensive: handle tuple (id, payload) fallback
elif isinstance(hit, tuple) and len(hit) >= 2:
_id = hit[0]
_payload = hit[1] if isinstance(hit[1], dict) else {}
hits.append({"id": _id, "document": _payload.get('document', '')})
# After reranking, take top_l
hits = hits[:top_l]
# Optionally filter by fuzzy/substring match if doc_search is True
if doc_search:
# Fuzzy: match any word in query_text (case-insensitive, partial match)
query_words = [w for w in re.split(r'\W+', query_text.lower()) if w]
def fuzzy_match(doc):
doc_l = doc.lower()
return any(qw in doc_l for qw in query_words)
filtered_hits = [hit for hit in hits if fuzzy_match(hit['document'])]
# Now also get all docs in the collection that match the substring/fuzzy (outside top_l reranked)
doc_hits = []
next_offset = None
while True:
scroll_result = self.client.scroll(collection_name=self.collection_name, with_payload=True, offset=next_offset)
points = scroll_result[0]
next_offset = scroll_result[1]
for point in points:
doc = point.payload.get('document', '') if hasattr(point, 'payload') else ''
if fuzzy_match(doc):
doc_hits.append({"id": point.id, "document": doc})
if not next_offset:
break
# Merge and deduplicate by id, prioritizing reranked hits
seen_ids = set()
merged = []
for hit in filtered_hits:
if hit['id'] not in seen_ids:
merged.append(hit)
seen_ids.add(hit['id'])
for hit in doc_hits:
if hit['id'] not in seen_ids:
merged.append(hit)
seen_ids.add(hit['id'])
return [f"{hit['document']}" for hit in merged] if merged else []
else:
return [f"{hit['document']}" for hit in hits] if hits else []
def save(self):
pass # Qdrant persists automatically
@classmethod
def load_database(cls, collection_prefix: str = "longterm_db", **kwargs):
return cls(collection_prefix=collection_prefix, **kwargs)
if __name__ == "__main__":
db = LongTermDatabase()
paths = [
r"C:\Users\dedeep vasireddy\Downloads\file3_literary_board_objectwise.json",
r"C:\Users\dedeep vasireddy\Downloads\file2_cultural_board_objectwise.json",
r"C:\Users\dedeep vasireddy\Downloads\file5_overall_coordinators_objectwise.json",
r"C:\Users\dedeep vasireddy\Downloads\file4_scitech_board_objectwise.json",
r"C:\Users\dedeep vasireddy\Downloads\file1_metadata_student_council_objectwise_full.json"
]
# for path in paths:
# db.add_data(path)
# db.add_data(paths[0]) # For testing, just add the first file
print("Total points in DB:", db.client.count(collection_name=db.collection_name).count)
query = input("Query: ")
res = db.smart_query(query, topk=10, top_l=5, use_late=True)
print("Final Results:", res)