import uuid
from itertools import zip_longest
from typing import Any, Dict, List, Literal, Optional, TypedDict

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchAny,
    MatchText,
    MatchValue,
    PointIdsList,
    PointStruct,
    VectorParams,
)

MatchType = Literal["eq", "text", "in"]


class MetaFilter(TypedDict):
    """One metadata condition understood by ``QdrantStore.build_qdrant_filter``.

    Example::

        filters = [
            {"field": "source", "op": "eq", "value": "file.pdf", "clause": "must"},
            {"field": "course", "op": "in", "value": ["math", "cs"], "clause": "should"},
            {"field": "bookmark_path", "op": "text", "value": "chapter1", "clause": "must"},
        ]
    """

    field: str  # metadata key; matched against the payload path "metadata.<field>"
    op: MatchType  # eq -> MatchValue, text -> MatchText, in -> MatchAny (value is a list)
    value: Any
    clause: Literal["must", "should", "must_not"]


class QdrantStore:
    """Wrapper around a single Qdrant collection of document chunks.

    Points written by this class carry the caller-supplied payload. Readers
    expect it to look like ``{"content": ..., "metadata": {...}}``, and all
    metadata filters address keys under ``metadata.``.
    """

    # Page size used when walking a collection with ``scroll``.
    _SCROLL_PAGE_SIZE = 256

    def __init__(self, client: QdrantClient, collection_name: str, vector_size: int):
        self.client = client
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.init_collection()

    def init_collection(self):
        """Create the collection (cosine distance, ``vector_size`` dims) unless it already exists."""
        existing = [c.name for c in self.client.get_collections().collections]
        if self.collection_name in existing:
            print(f"[INFO] Collection '{self.collection_name}' exists.")
        else:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            )
            print(f"[INFO] Created collection '{self.collection_name}' with vector size {self.vector_size}")

    def _scroll_all(self, scroll_filter: Filter | None = None, with_payload: bool = True):
        """Return every point matching ``scroll_filter``, following scroll pagination to the end."""
        points = []
        offset = None
        while True:
            batch, offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=with_payload,
                with_vectors=False,
            )
            points.extend(batch)
            # Qdrant signals the last page with a None next-page offset.
            if offset is None:
                return points

    def upsert_embeddings(
        self,
        client: QdrantClient,
        collection: str,
        embeddings: List[List[float]],
        payloads: List[Dict[str, Any]],
        batch_size: int = 64,
    ):
        """Insert ``embeddings`` with their ``payloads`` in batches of ``batch_size``.

        Each point gets a fresh random UUID-string ID. Entries whose embedding
        is ``None`` are skipped, and pairing stops at the shorter of the two
        lists.

        NOTE: ``client`` and ``collection`` are accepted for backward
        compatibility but are ignored. Writes always go to ``self.client`` /
        ``self.collection_name``.
        """
        for start in range(0, len(embeddings), batch_size):
            batch_embs = embeddings[start:start + batch_size]
            batch_payloads = payloads[start:start + batch_size]
            points = [
                PointStruct(id=str(uuid.uuid4()), vector=emb, payload=payload)
                for emb, payload in zip(batch_embs, batch_payloads)
                if emb is not None
            ]
            if points:
                self.client.upsert(collection_name=self.collection_name, points=points)
                print(f"[INFO] Inserted batch {start // batch_size + 1} ({len(points)} vectors)")

    def delete_by_id(self, client: QdrantClient, collection: str, point_id: str):
        """Delete one point from ``collection`` using ``client``.

        Qdrant point IDs are unsigned integers or UUIDs. ``upsert_embeddings``
        creates UUID strings, so only purely numeric IDs are converted to
        ``int``; anything else is passed through as a string. Failures are
        logged and swallowed (best effort).
        """
        if isinstance(point_id, int):
            selector_id = point_id
        else:
            text = str(point_id).strip()
            selector_id = int(text) if text.isdecimal() else text
        try:
            client.delete(
                collection_name=collection,
                points_selector=PointIdsList(points=[selector_id]),
            )
            print(f"[INFO] Deleted point ID: {point_id}")
        except Exception as exc:
            print(f"[ERROR] Failed to delete point {point_id}: {exc}")

    def build_qdrant_filter(self, filters: list[MetaFilter] | None) -> Filter | None:
        """Translate a list of ``MetaFilter`` dicts into a Qdrant ``Filter``.

        Returns ``None`` when ``filters`` is empty or ``None``.

        Raises:
            ValueError: on an unsupported ``op`` or ``clause``. An unknown
                clause used to be silently dropped, which could widen a query,
                e.g. by losing a user-scoping condition.
        """
        if not filters:
            return None

        buckets: dict[str, list[FieldCondition]] = {"must": [], "should": [], "must_not": []}
        for f in filters:
            key = f"metadata.{f['field']}"
            op = f["op"]
            value = f["value"]

            if op == "eq":
                match = MatchValue(value=value)
            elif op == "text":
                match = MatchText(text=value)
            elif op == "in":
                match = MatchAny(any=value)
            else:
                raise ValueError(f"Unsupported op: {op}")

            clause = f["clause"]
            if clause not in buckets:
                raise ValueError(f"Unsupported clause: {clause}")
            buckets[clause].append(FieldCondition(key=key, match=match))

        return Filter(
            must=buckets["must"] or None,
            should=buckets["should"] or None,
            must_not=buckets["must_not"] or None,
        )

    def query_qdrant(
        self,
        filters: list[MetaFilter] | None = None,
        embedding: List[float] | None = None,
        top_k: int = 5,
    ):
        """Return up to ``top_k`` points as ``{"id", "score", "content", "metadata"}`` dicts.

        With an ``embedding``, a similarity search (``query_points``) is run
        and ``score`` is filled in. Without one, matching points are scrolled
        and ``score`` is ``None``. Qdrant errors are logged and yield ``[]``.
        Filter-building errors (``ValueError``) propagate.
        """
        query_filter = self.build_qdrant_filter(filters)

        try:
            if embedding is not None:
                response = self.client.query_points(
                    collection_name=self.collection_name,
                    query=embedding,
                    query_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                points = response.points
                with_score = True
            else:
                points, _ = self.client.scroll(
                    collection_name=self.collection_name,
                    scroll_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                with_score = False

            return [
                {
                    "id": p.id,
                    "score": p.score if with_score else None,
                    "content": p.payload.get("content"),
                    "metadata": p.payload.get("metadata"),
                }
                for p in points
            ]
        except Exception as e:
            print(f"[ERROR] Qdrant query failed: {e}")
            return []

    def get_all_documents(self):
        """Return ``{"id", "content", "metadata"}`` for every point (all scroll pages).

        Errors are logged and yield ``[]``.
        """
        try:
            return [
                {
                    "id": p.id,
                    "content": p.payload.get("content"),
                    "metadata": p.payload.get("metadata"),
                }
                for p in self._scroll_all(with_payload=True)
            ]
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all documents: {e}")
            return []

    def get_all_files(self):
        """Return the distinct ``(source, username, course)`` triples in the collection.

        Points missing any of the three metadata values (or with no metadata)
        are skipped. Order is unspecified. Errors are logged and yield ``[]``.
        """
        try:
            triples = set()
            for p in self._scroll_all(with_payload=True):
                metadata = p.payload.get("metadata") or {}
                source = metadata.get("source")
                username = metadata.get("username")
                course = metadata.get("course")
                if source and username and course:
                    triples.add((source, username, course))
            return list(triples)
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all files: {e}")
            return []

    def remove_collection(self):
        """Drop this store's collection. Errors are logged, not raised."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"[INFO] Collection '{self.collection_name}' deleted.")
        except Exception as e:
            print(f"[ERROR] Failed to delete collection: {e}")

    def list_collections(self):
        """Return the names of all collections on the server (``[]`` on error)."""
        try:
            return [c.name for c in self.client.get_collections().collections]
        except Exception as e:
            print(f"[ERROR] Failed to list collections: {e}")
            return []

    def remove_points_by_file(self, source_file: str, username: str, course: str):
        """Delete every chunk whose metadata matches ``source_file``, ``username`` and ``course``.

        All scroll pages are walked. A single ``limit=10000`` page previously
        left chunks of large files behind.

        Returns:
            True if points were found and deleted. False if none matched or an
            error occurred (the error is logged).
        """
        try:
            file_filter = Filter(
                must=[
                    FieldCondition(key="metadata.source", match=MatchValue(value=source_file)),
                    FieldCondition(key="metadata.username", match=MatchValue(value=username)),
                    FieldCondition(key="metadata.course", match=MatchValue(value=course)),
                ]
            )
            point_ids = [p.id for p in self._scroll_all(scroll_filter=file_filter, with_payload=False)]
            print(f"[INFO] Found {len(point_ids)} points for file '{source_file}' to delete.")

            if not point_ids:
                print(f"[INFO] No points found for file '{source_file}' to delete.")
                return False

            self.client.delete(
                collection_name=self.collection_name,
                points_selector=PointIdsList(points=point_ids),
            )
            print(f"[INFO] Deleted {len(point_ids)} points for file '{source_file}'")
            return True
        except Exception as e:
            print(f"[ERROR] Failed to delete points for file '{source_file}': {e}")
            return False

    def all_user_files_bookmarks(self, username: str):
        """Group a user's bookmark paths per source file into a nested dict.

        ``metadata.bookmark_path`` is expected to be a list such as
        ``["Part", "Chapter", "Section"]``; other shapes are skipped. The result
        has the shape::

            {source: {part: {chapter: [section, ...]}}}

        A one-element path (e.g. ``["Preface"]``) registers the part with an
        empty ``"_sections"`` list. Path elements beyond the third are ignored.
        Errors are logged and yield ``{}``.
        """
        try:
            user_filter = Filter(
                must=[FieldCondition(key="metadata.username", match=MatchValue(value=username))]
            )

            # source -> distinct bookmark paths, in first-seen order
            raw: dict[str, list[list[str]]] = {}
            for p in self._scroll_all(scroll_filter=user_filter, with_payload=True):
                metadata = p.payload.get("metadata") or {}
                source = metadata.get("source")
                bookmark_path = metadata.get("bookmark_path")
                if not source or not isinstance(bookmark_path, list):
                    continue
                paths = raw.setdefault(source, [])
                if bookmark_path not in paths:
                    paths.append(bookmark_path)

            # Build nested dict: source -> part -> chapter -> [sections]
            result = {}
            for source, paths in raw.items():
                nested = {}
                for path in paths:
                    if not path:
                        continue
                    part = path[0]
                    chapter = path[1] if len(path) > 1 else None
                    section = path[2] if len(path) > 2 else None

                    chapters = nested.setdefault(part, {})
                    if chapter is None:
                        # top-level bookmark (e.g. ["Preface"])
                        chapters.setdefault("_sections", [])
                        continue

                    sections = chapters.setdefault(chapter, [])
                    if section and section not in sections:
                        sections.append(section)
                result[source] = nested

            print(f"[INFO] Retrieved grouped bookmarks for user '{username}': {result}")
            return result
        except Exception as e:
            print(f"[ERROR] Failed to retrieve user files and bookmarks: {e}")
            return {}

    def retrieve_chunks_by_topic(
        self,
        username: str,
        course: str,
        topic_embeddings,
        refernces: Optional[List[Any]] = None,
        chunks_per_topic: int = 5,
    ):
        """Retrieve up to ``chunks_per_topic`` chunks nearest to one topic vector.

        Args:
            username: only chunks whose ``metadata.username`` equals this.
            course: only chunks whose ``metadata.course`` equals this.
            topic_embeddings: query vector for the topic.
            refernces: optional references (parameter name kept as-is for
                caller compatibility). Each is an object with ``.filename``
                and ``.bookmarks`` attributes, or a dict with those keys.
                Without references the whole course is searched. With them,
                each file is searched separately. A file with bookmarks is
                searched once per bookmark (text match on
                ``metadata.bookmark_path``), splitting the per-topic budget
                between its bookmarks.
            chunks_per_topic: maximum number of chunks returned.

        Returns:
            A list of result dicts as produced by ``query_qdrant``.
        """
        if chunks_per_topic <= 0:
            return []

        base_filter: list[MetaFilter] = [
            {"field": "username", "op": "eq", "value": username, "clause": "must"},
            {"field": "course", "op": "eq", "value": course, "clause": "must"},
        ]
        results = []
        bookmarked_only = False

        if not refernces:
            results.extend(
                self.query_qdrant(filters=base_filter, embedding=topic_embeddings, top_k=chunks_per_topic)
            )
        else:
            for ref in refernces:
                if isinstance(ref, dict):
                    filename, bookmarks = ref["filename"], ref.get("bookmarks")
                else:
                    filename, bookmarks = ref.filename, ref.bookmarks
                bookmarks = bookmarks or []

                # Fresh lists per query instead of append/pop on a shared list.
                source_filter = base_filter + [
                    {"field": "source", "op": "eq", "value": filename, "clause": "must"}
                ]

                if not bookmarks:
                    results.extend(
                        self.query_qdrant(
                            filters=source_filter, embedding=topic_embeddings, top_k=chunks_per_topic
                        )
                    )
                    continue

                bookmarked_only = True
                # Split the budget across bookmarks, but ask for at least one chunk
                # each: plain integer division gives top_k=0 when there are more
                # bookmarks than chunks_per_topic.
                per_bookmark = max(1, chunks_per_topic // len(bookmarks))
                for bookmark in bookmarks:
                    bookmark_filter = source_filter + [
                        {"field": "bookmark_path", "op": "text", "value": bookmark, "clause": "must"}
                    ]
                    results.extend(
                        self.query_qdrant(
                            filters=bookmark_filter, embedding=topic_embeddings, top_k=per_bookmark
                        )
                    )

        def has_bookmark(chunk):
            # metadata may be stored as None; treat that as "no bookmark".
            return bool((chunk.get("metadata") or {}).get("bookmark_path"))

        if bookmarked_only:
            # NOTE(review): this also drops the results of references that had no
            # bookmarks whenever another reference did. Confirm that is intended.
            results = [r for r in results if has_bookmark(r)]
        else:
            # Alternate bookmarked / non-bookmarked chunks (bookmarked first);
            # once one side runs out, the other fills the remaining slots.
            with_bm = [r for r in results if has_bookmark(r)]
            without_bm = [r for r in results if not has_bookmark(r)]
            missing = object()
            results = [
                r
                for pair in zip_longest(with_bm, without_bm, fillvalue=missing)
                for r in pair
                if r is not missing
            ]

        return results[:chunks_per_topic]

    def retrieve_for_exam(
        self,
        topics: List,
        username: str,
        course: Optional[str] = None,
        references: Optional[List[Any]] = None,
        chunks_per_topic: int = 5,
    ):
        """Run ``retrieve_chunks_by_topic`` for each topic.

        ``topics`` items are indexed as ``topic[0]`` (topic text, used as the
        result key) and ``topic[1]`` (its embedding). Returns
        ``{topic_text: [chunk, ...]}``; a repeated topic text keeps the last
        result.
        """
        exam_chunks = {}
        for topic in topics:
            exam_chunks[topic[0]] = self.retrieve_chunks_by_topic(
                username=username,
                course=course,
                topic_embeddings=topic[1],
                refernces=references,
                chunks_per_topic=chunks_per_topic,
            )
        return exam_chunks