# MinaNasser's picture
# 1st
# 1bc3f18
from typing import List, Dict, Any, Literal, Optional, TypedDict
from qdrant_client import QdrantClient
from qdrant_client.models import (VectorParams,Distance,PointStruct,Filter,
FieldCondition,MatchValue,PointIdsList,MatchText,MatchAny)
import uuid
# Comparison operators understood by the metadata filter builder.
MatchType = Literal["eq", "text", "in"]


class MetaFilter(TypedDict):
    """One metadata condition used when filtering stored points.

    Keys:
        field: metadata key the condition applies to.
        op: comparison operator, one of ``eq``, ``text`` or ``in``.
        value: value (or values, for ``in``) to compare against.
        clause: boolean group the condition joins: ``must``, ``should``
            or ``must_not``.
    """

    field: str
    op: MatchType
    value: Any
    clause: Literal["must", "should", "must_not"]
# filters = [
# {"field": "source", "op": "eq", "value": "file.pdf", "clause": "must"},
# {"field": "course", "op": "in", "value": ["math", "cs"], "clause": "should"},
# {"field": "bookmark_path", "op": "text", "value": "chapter1", "clause": "must"},
# ]
class QdrantStore:
    """Thin wrapper around a ``QdrantClient`` bound to a single collection.

    Points are stored with a payload that is read back through the
    ``content`` and ``metadata`` keys; metadata filters address fields as
    ``metadata.<field>``.
    """

    # Page size used when a method has to walk through every matching point.
    _SCROLL_PAGE_SIZE = 1000

    def __init__(self, client: QdrantClient, collection_name: str, vector_size: int):
        """Store the client/collection settings and make sure the collection exists."""
        self.client = client
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.init_collection()

    def init_collection(self):
        """Create the collection (cosine distance) unless it already exists."""
        existing = {c.name for c in self.client.get_collections().collections}
        if self.collection_name in existing:
            print(f"[INFO] Collection '{self.collection_name}' exists. ")
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
        )
        print(f"[INFO] Created collection '{self.collection_name}' with vector size {self.vector_size}")

    def upsert_embeddings(
        self,
        client: QdrantClient,
        collection: str,
        embeddings: List[List[float]],
        payloads: List[Dict[str, Any]],
        batch_size: int = 64,
    ):
        """Insert embeddings with their payloads in batches.

        Args:
            client: Ignored; kept so existing callers keep working. The
                store's own client is used.
            collection: Ignored; kept so existing callers keep working. The
                store's own collection name is used.
            embeddings: One vector per payload. ``None`` entries are skipped.
            payloads: Payload stored with the vector at the same index.
            batch_size: Number of (embedding, payload) pairs sent per request.
        """
        if len(embeddings) != len(payloads):
            # zip() below stops at the shorter input; make the data loss visible.
            print(
                f"[WARN] Got {len(embeddings)} embeddings but {len(payloads)} payloads; "
                f"only the first {min(len(embeddings), len(payloads))} pairs are used"
            )

        for start in range(0, len(embeddings), batch_size):
            batch_embs = embeddings[start:start + batch_size]
            batch_payloads = payloads[start:start + batch_size]
            points = [
                # Every point gets a fresh random UUID as its id.
                PointStruct(id=str(uuid.uuid4()), vector=emb, payload=payload)
                for emb, payload in zip(batch_embs, batch_payloads)
                if emb is not None
            ]
            if points:
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=points,
                )
                print(f"[INFO] Inserted batch {start//batch_size + 1} ({len(points)} vectors)")

    @staticmethod
    def _normalize_point_id(point_id):
        """Return *point_id* as an int when it is integer-like, else as a string id.

        Points written by ``upsert_embeddings`` use UUID strings, which cannot
        be converted with ``int()`` and must be passed through unchanged.
        """
        if isinstance(point_id, uuid.UUID):
            # int(UUID) would succeed but yield an unrelated 128-bit integer.
            return str(point_id)
        try:
            return int(point_id)
        except (TypeError, ValueError):
            return point_id

    def delete_by_id(self, client: QdrantClient, collection: str, point_id: str):
        """Delete a single point from *collection* using *client*.

        Args:
            client: Client used for the delete request.
            collection: Name of the collection holding the point.
            point_id: Integer-like id or UUID string of the point.

        Failures are printed, not raised.
        """
        try:
            client.delete(
                collection_name=collection,
                points_selector=PointIdsList(points=[self._normalize_point_id(point_id)]),
            )
            print(f"[INFO] Deleted point ID: {point_id}")
        except Exception as exc:
            print(f"[ERROR] Failed to delete point {point_id}: {exc}")

    def build_qdrant_filter(self, filters: list[MetaFilter] | None) -> Filter | None:
        """Translate ``MetaFilter`` dicts into a Qdrant ``Filter``.

        Returns:
            ``None`` when *filters* is empty or ``None``; otherwise a ``Filter``
            whose ``must`` / ``should`` / ``must_not`` groups hold the
            conditions (empty groups are passed as ``None``).

        Raises:
            ValueError: If a filter uses an unsupported ``op`` or ``clause``.
        """
        if not filters:
            return None

        grouped = {"must": [], "should": [], "must_not": []}
        for f in filters:
            key = f"metadata.{f['field']}"
            op = f["op"]
            value = f["value"]
            if op not in ("eq", "text", "in"):
                raise ValueError(f"Unsupported op: {op}")
            clause = f["clause"]
            if clause not in grouped:
                # Dropping the condition silently would widen the query.
                raise ValueError(f"Unsupported clause: {clause}")

            if op == "eq":
                match = MatchValue(value=value)
            elif op == "text":
                match = MatchText(text=value)
            else:
                match = MatchAny(any=value)
            grouped[clause].append(FieldCondition(key=key, match=match))

        return Filter(
            must=grouped["must"] or None,
            should=grouped["should"] or None,
            must_not=grouped["must_not"] or None,
        )

    def query_qdrant(
        self,
        filters: list[MetaFilter] | None = None,
        embedding: List[float] | None = None,
        top_k: int = 5,
    ):
        """Fetch up to *top_k* points, by similarity when *embedding* is given.

        With an embedding the points come from ``query_points`` and carry a
        score; without one they come from a single ``scroll`` page and the
        score is ``None``.

        Returns:
            A list of dicts with ``id``, ``score``, ``content`` and
            ``metadata`` keys, or ``[]`` if the request fails.
        """
        query_filter = self.build_qdrant_filter(filters)
        try:
            if embedding is not None:
                response = self.client.query_points(
                    collection_name=self.collection_name,
                    query=embedding,
                    query_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                points = response.points
                with_score = True
            else:
                points, _ = self.client.scroll(
                    collection_name=self.collection_name,
                    scroll_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                with_score = False

            hits = []
            for p in points:
                payload = p.payload or {}  # a point may carry no payload at all
                hits.append(
                    {
                        "id": p.id,
                        "score": p.score if with_score else None,
                        "content": payload.get("content"),
                        "metadata": payload.get("metadata"),
                    }
                )
            return hits
        except Exception as e:
            print(f"[ERROR] Qdrant query failed: {e}")
            return []

    def _scroll_all(self, scroll_filter=None, with_payload=True):
        """Return every point matching *scroll_filter*, following pagination."""
        collected = []
        offset = None
        while True:
            page, offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=with_payload,
                with_vectors=False,
            )
            collected.extend(page)
            if offset is None:  # no further page
                return collected

    def get_all_documents(self):
        """Return ``id``/``content``/``metadata`` dicts for every stored point.

        Returns ``[]`` if the request fails.
        """
        try:
            documents = []
            for p in self._scroll_all(with_payload=True):
                payload = p.payload or {}
                documents.append(
                    {
                        "id": p.id,
                        "content": payload.get("content"),
                        "metadata": payload.get("metadata"),
                    }
                )
            return documents
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all documents: {e}")
            return []

    def get_all_files(self):
        """Return the unique ``(source, username, course)`` tuples in the collection.

        Points whose metadata lacks any of the three values are ignored.
        Returns ``[]`` if the request fails.
        """
        try:
            files_usernames_courses = set()
            for p in self._scroll_all(with_payload=True):
                metadata = (p.payload or {}).get("metadata") or {}
                source = metadata.get("source")
                username = metadata.get("username")
                course = metadata.get("course")
                if source and username and course:
                    files_usernames_courses.add((source, username, course))
            return list(files_usernames_courses)
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all files: {e}")
            return []

    def remove_collection(self):
        """Delete the whole collection; failures are printed, not raised."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"[INFO] Collection '{self.collection_name}' deleted.")
        except Exception as e:
            print(f"[ERROR] Failed to delete collection: {e}")

    def list_collections(self):
        """Return the names of all collections, or ``[]`` if the request fails."""
        try:
            return [c.name for c in self.client.get_collections().collections]
        except Exception as e:
            print(f"[ERROR] Failed to list collections: {e}")
            return []

    def remove_points_by_file(self, source_file: str, username: str, course: str):
        """Delete every point of one file belonging to *username* and *course*.

        Returns:
            ``True`` if at least one point was deleted, ``False`` if none
            matched or the operation failed.
        """
        try:
            file_filter = Filter(
                must=[
                    FieldCondition(key="metadata.source", match=MatchValue(value=source_file)),
                    FieldCondition(key="metadata.username", match=MatchValue(value=username)),
                    FieldCondition(key="metadata.course", match=MatchValue(value=course)),
                ]
            )
            # Only ids are needed, so payloads are not fetched.
            point_ids = [p.id for p in self._scroll_all(scroll_filter=file_filter, with_payload=False)]
            print(f"[INFO] Found {len(point_ids)} points for file '{source_file}' to delete.")
            if not point_ids:
                print(f"[INFO] No points found for file '{source_file}' to delete.")
                return False
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=PointIdsList(points=point_ids),
            )
            print(f"[INFO] Deleted {len(point_ids)} points for file '{source_file}'")
            return True
        except Exception as e:
            print(f"[ERROR] Failed to delete points for file '{source_file}': {e}")
            return False

    @staticmethod
    def _nest_bookmark_paths(paths):
        """Group bookmark paths into ``{part: {chapter: [sections]}}``.

        A path with only a part (e.g. ``["Preface"]``) is recorded as an empty
        ``"_sections"`` list under that part.
        """
        nested = {}
        for path in paths:
            if len(path) == 0:
                continue
            part = path[0]
            chapter = path[1] if len(path) > 1 else None
            section = path[2] if len(path) > 2 else None

            chapters = nested.setdefault(part, {})
            if chapter is None:
                chapters.setdefault("_sections", [])
                continue
            sections = chapters.setdefault(chapter, [])
            if section and section not in sections:
                sections.append(section)
        return nested

    def all_user_files_bookmarks(self, username: str):
        """Return the bookmark tree of every file stored for *username*.

        Returns:
            ``{source: {part: {chapter: [sections]}}}`` built from the
            ``bookmark_path`` lists found in the points' metadata, or ``{}``
            if the request fails. Points without a source or without a list
            ``bookmark_path`` are ignored.
        """
        try:
            user_filter = Filter(
                must=[
                    FieldCondition(key="metadata.username", match=MatchValue(value=username)),
                ]
            )
            raw: dict[str, list[list[str]]] = {}
            for p in self._scroll_all(scroll_filter=user_filter, with_payload=True):
                metadata = (p.payload or {}).get("metadata") or {}
                source = metadata.get("source")
                bookmark_path = metadata.get("bookmark_path")  # list like ["Part", "Chapter", "Section"]
                if not source or not isinstance(bookmark_path, list):
                    continue
                known_paths = raw.setdefault(source, [])
                if bookmark_path not in known_paths:
                    known_paths.append(bookmark_path)

            result = {source: self._nest_bookmark_paths(paths) for source, paths in raw.items()}
            print(f"[INFO] Retrieved grouped bookmarks for user '{username}': {result}")
            return result
        except Exception as e:
            print(f"[ERROR] Failed to retrieve user files and bookmarks: {e}")
            return {}

    @staticmethod
    def _reference_parts(ref):
        """Return ``(filename, bookmarks)`` for a reference given as a dict or an object."""
        if isinstance(ref, dict):
            return ref["filename"], ref.get("bookmarks") or []
        return ref.filename, ref.bookmarks or []

    @staticmethod
    def _has_bookmark(chunk):
        """Tell whether a result dict carries a non-empty ``bookmark_path``."""
        # "metadata" may be present with a None value, so .get's default is not enough.
        return bool((chunk.get("metadata") or {}).get("bookmark_path"))

    def retrieve_chunks_by_topic(self, username: str, course: str, topic_embeddings,
                                 refernces: Optional[List[dict]] = None, chunks_per_topic: int = 5):
        """Retrieve up to *chunks_per_topic* chunks similar to a topic embedding.

        Args:
            username: Only chunks whose metadata has this username are searched.
            course: Course to restrict the search to; ``None`` disables the
                course restriction.
            topic_embeddings: Query vector of the topic.
            refernces: Optional references (dicts or objects) exposing
                ``filename`` and ``bookmarks``. Each reference restricts a
                query to that file and, when bookmarks are given, to each
                bookmark in turn. (The misspelled name is part of the
                existing interface.)
            chunks_per_topic: Maximum number of chunks returned.

        Returns:
            A list of result dicts as produced by ``query_qdrant``. If any
            reference had bookmarks, only chunks with a ``bookmark_path`` are
            kept; otherwise bookmarked and non-bookmarked chunks alternate.
        """
        if chunks_per_topic <= 0:
            return []

        base_filters = [
            {"field": "username", "op": "eq", "value": username, "clause": "must"},
        ]
        if course is not None:
            base_filters.append({"field": "course", "op": "eq", "value": course, "clause": "must"})

        results = []
        bookmarked_only = False
        if refernces:
            for ref in refernces:
                filename, bookmarks = self._reference_parts(ref)
                source_filters = base_filters + [
                    {"field": "source", "op": "eq", "value": filename, "clause": "must"},
                ]
                if not bookmarks:
                    results.extend(
                        self.query_qdrant(
                            filters=source_filters,
                            embedding=topic_embeddings,
                            top_k=chunks_per_topic,
                        )
                    )
                    continue

                bookmarked_only = True
                # Split the budget across bookmarks, but always ask for at least one chunk.
                per_bookmark = max(1, chunks_per_topic // len(bookmarks))
                for bookmark in bookmarks:
                    bookmark_filters = source_filters + [
                        {"field": "bookmark_path", "op": "text", "value": bookmark, "clause": "must"},
                    ]
                    results.extend(
                        self.query_qdrant(
                            filters=bookmark_filters,
                            embedding=topic_embeddings,
                            top_k=per_bookmark,
                        )
                    )
        else:
            results.extend(
                self.query_qdrant(
                    filters=base_filters,
                    embedding=topic_embeddings,
                    top_k=chunks_per_topic,
                )
            )

        if bookmarked_only:
            results = [r for r in results if self._has_bookmark(r)]
        else:
            bookmarked = [r for r in results if self._has_bookmark(r)]
            non_bookmarked = [r for r in results if not self._has_bookmark(r)]
            # Alternate bookmarked / non-bookmarked chunks until the budget is met.
            results = []
            for idx in range(max(len(bookmarked), len(non_bookmarked))):
                if idx < len(bookmarked):
                    results.append(bookmarked[idx])
                if idx < len(non_bookmarked):
                    results.append(non_bookmarked[idx])
                if len(results) >= chunks_per_topic:
                    break
        return results[:chunks_per_topic]

    def retrieve_for_exam(self, topics: List, username: str, course: Optional[str] = None,
                          references: Optional[List[dict]] = None, chunks_per_topic: int = 5):
        """Retrieve chunks for every topic of an exam.

        Args:
            topics: Sequence of ``(topic_name, topic_embedding)`` pairs.
            username: Passed on to ``retrieve_chunks_by_topic``.
            course: Passed on to ``retrieve_chunks_by_topic``.
            references: Passed on to ``retrieve_chunks_by_topic``.
            chunks_per_topic: Maximum number of chunks per topic.

        Returns:
            A dict mapping each topic name to its list of chunks.
        """
        exam_chunks = {}
        for topic in topics:
            # topic[0] is the topic name, topic[1] its embedding.
            exam_chunks[topic[0]] = self.retrieve_chunks_by_topic(
                username=username,
                course=course,
                topic_embeddings=topic[1],
                refernces=references,
                chunks_per_topic=chunks_per_topic,
            )
        return exam_chunks