Spaces:
Paused
Paused
| from typing import List, Dict, Any, Literal, Optional, TypedDict | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import (VectorParams,Distance,PointStruct,Filter, | |
| FieldCondition,MatchValue,PointIdsList,MatchText,MatchAny) | |
| import uuid | |
# Match operators accepted in a MetaFilter's "op" key.
MatchType = Literal["eq", "text", "in"]


class MetaFilter(TypedDict):
    """A single metadata condition, as consumed by ``QdrantStore.build_qdrant_filter``.

    Example::

        filters = [
            {"field": "source", "op": "eq", "value": "file.pdf", "clause": "must"},
            {"field": "course", "op": "in", "value": ["math", "cs"], "clause": "should"},
            {"field": "bookmark_path", "op": "text", "value": "chapter1", "clause": "must"},
        ]
    """

    # Metadata key the condition applies to.
    field: str
    # How the value is matched: "eq" (exact), "text" (text match) or "in" (any of a list).
    op: MatchType
    # Value to compare against; a list when op is "in".
    value: Any
    # Boolean role of the condition inside the combined filter.
    clause: Literal["must", "should", "must_not"]
class QdrantStore:
    """Wrapper around one Qdrant collection of embedded text chunks.

    Read paths expect each point's payload to hold a ``"content"`` value and a
    ``"metadata"`` dict; metadata filters are applied to ``metadata.<field>``.
    Most methods print errors and return an empty/False result instead of
    raising (best-effort behaviour kept from the original implementation).
    """

    def __init__(self, client: QdrantClient, collection_name: str, vector_size: int):
        """Bind to ``collection_name`` on ``client``, creating the collection if missing.

        Args:
            client: Qdrant client used for every operation.
            collection_name: Collection managed by this store.
            vector_size: Vector dimensionality used if the collection has to be
                created (cosine distance).
        """
        self.client = client
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.init_collection()

    def init_collection(self):
        """Create the collection (cosine distance) unless it already exists.

        NOTE(review): an existing collection is reused as-is; its vector size
        is not compared against ``self.vector_size``.
        """
        existing = {c.name for c in self.client.get_collections().collections}
        if self.collection_name in existing:
            print(f"[INFO] Collection '{self.collection_name}' exists. ")
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
        )
        print(f"[INFO] Created collection '{self.collection_name}' with vector size {self.vector_size}")

    # ------------------------------------------------------------------ helpers

    def _scroll_all(
        self,
        scroll_filter: Optional[Filter] = None,
        with_payload: bool = True,
        page_size: int = 1000,
    ):
        """Yield every point matching ``scroll_filter``, following scroll pagination.

        One ``scroll`` call returns at most ``limit`` points plus the offset of
        the next page; looping until that offset is None is what guarantees no
        points are silently dropped (a single large-limit call truncates).
        """
        offset = None
        while True:
            points, offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=page_size,
                offset=offset,
                with_payload=with_payload,
                with_vectors=False,
            )
            yield from points
            if offset is None:
                return

    @staticmethod
    def _metadata_of(point) -> Dict[str, Any]:
        """Return the point's ``metadata`` payload dict, or ``{}`` if missing or null."""
        return (point.payload or {}).get("metadata") or {}

    @staticmethod
    def _coerce_point_id(point_id):
        """Map ``point_id`` to a Qdrant ID: int for decimal strings, otherwise unchanged.

        Qdrant IDs are unsigned integers or UUIDs; ``upsert_embeddings`` stores
        UUID strings, so those must not be forced through ``int()``.
        """
        if isinstance(point_id, int):
            return point_id
        text = str(point_id).strip()
        return int(text) if text.isdecimal() else text

    @staticmethod
    def _ref_attr(ref, name: str):
        """Read ``name`` from a reference given either as a dict or as an object."""
        return ref.get(name) if isinstance(ref, dict) else getattr(ref, name)

    # ------------------------------------------------------------------- writes

    def upsert_embeddings(
        self,
        client: QdrantClient,
        collection: str,
        embeddings: List[List[float]],
        payloads: List[Dict[str, Any]],
        batch_size: int = 64,
    ):
        """Insert embeddings and their payloads in batches, each under a fresh UUID.

        Args:
            client: Ignored (kept for backward compatibility); ``self.client`` is used.
            collection: Ignored (kept for backward compatibility);
                ``self.collection_name`` is used.
            embeddings: One vector per item; ``None`` entries are skipped.
            payloads: Payload for each embedding, aligned by index.
            batch_size: Number of items sent per upsert request.

        Raises:
            ValueError: if ``embeddings`` and ``payloads`` differ in length
                (``zip`` would otherwise silently drop the tail), or if
                ``batch_size`` is not positive.
        """
        if len(embeddings) != len(payloads):
            raise ValueError(
                f"embeddings ({len(embeddings)}) and payloads ({len(payloads)}) "
                "must have the same length"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        for start in range(0, len(embeddings), batch_size):
            stop = start + batch_size
            points = [
                PointStruct(id=str(uuid.uuid4()), vector=emb, payload=payload)
                for emb, payload in zip(embeddings[start:stop], payloads[start:stop])
                if emb is not None
            ]
            if points:
                self.client.upsert(collection_name=self.collection_name, points=points)
                print(f"[INFO] Inserted batch {start // batch_size + 1} ({len(points)} vectors)")

    def delete_by_id(self, client: QdrantClient, collection: str, point_id: str):
        """Delete a single point by ID.

        Args:
            client: Client to delete through; ``None`` falls back to ``self.client``.
            collection: Target collection; ``None`` falls back to ``self.collection_name``.
            point_id: Point ID. Decimal strings/ints are sent as integer IDs;
                anything else (e.g. the UUID strings assigned by
                ``upsert_embeddings``) is sent unchanged.

        Returns:
            True if the delete request succeeded, False otherwise (error printed).
        """
        client = self.client if client is None else client
        collection = self.collection_name if collection is None else collection
        try:
            client.delete(
                collection_name=collection,
                points_selector=PointIdsList(points=[self._coerce_point_id(point_id)]),
            )
            print(f"[INFO] Deleted point ID: {point_id}")
            return True
        except Exception as exc:
            print(f"[ERROR] Failed to delete point {point_id}: {exc}")
            return False

    # ------------------------------------------------------------------ queries

    def build_qdrant_filter(self, filters: list[MetaFilter] | None) -> Filter | None:
        """Translate ``MetaFilter`` dicts into a Qdrant ``Filter`` on ``metadata.<field>``.

        Returns:
            None when ``filters`` is empty or None (no filtering), else a Filter
            whose unused clause lists are None.

        Raises:
            ValueError: on an unsupported ``op`` or ``clause``. (Unknown clauses
                used to be dropped silently, which widened the query.)
        """
        if not filters:
            return None
        clauses: Dict[str, list] = {"must": [], "should": [], "must_not": []}
        for f in filters:
            key = f"metadata.{f['field']}"
            op = f["op"]
            value = f["value"]
            if op == "eq":
                cond = FieldCondition(key=key, match=MatchValue(value=value))
            elif op == "text":
                cond = FieldCondition(key=key, match=MatchText(text=value))
            elif op == "in":
                cond = FieldCondition(key=key, match=MatchAny(any=value))
            else:
                raise ValueError(f"Unsupported op: {op}")
            clause = f["clause"]
            if clause not in clauses:
                raise ValueError(f"Unsupported clause: {clause}")
            clauses[clause].append(cond)
        return Filter(
            must=clauses["must"] or None,
            should=clauses["should"] or None,
            must_not=clauses["must_not"] or None,
        )

    def query_qdrant(
        self,
        filters: list[MetaFilter] | None = None,
        embedding: List[float] | None = None,
        top_k: int = 5,
    ):
        """Fetch up to ``top_k`` chunks matching ``filters``.

        With ``embedding``: vector search via ``query_points``; each result has
        its score. Without: a plain ``scroll`` of matching points; ``score`` is None.

        Returns:
            List of ``{"id", "score", "content", "metadata"}`` dicts; ``[]`` if
            the Qdrant call fails (error printed). Errors raised while building
            the filter (e.g. unsupported op) propagate to the caller.
        """
        query_filter = self.build_qdrant_filter(filters)
        try:
            if embedding is not None:
                response = self.client.query_points(
                    collection_name=self.collection_name,
                    query=embedding,
                    query_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                points, with_score = response.points, True
            else:
                points, _ = self.client.scroll(
                    collection_name=self.collection_name,
                    scroll_filter=query_filter,
                    limit=top_k,
                    with_payload=True,
                )
                with_score = False
            results = []
            for p in points:
                payload = p.payload or {}
                results.append({
                    "id": p.id,
                    "score": p.score if with_score else None,
                    "content": payload.get("content"),
                    "metadata": payload.get("metadata"),
                })
            return results
        except Exception as e:
            print(f"[ERROR] Qdrant query failed: {e}")
            return []

    def get_all_documents(self):
        """Return every point as ``{"id", "content", "metadata"}`` (paginated, no cap).

        Returns ``[]`` on failure (error printed).
        """
        try:
            documents = []
            for p in self._scroll_all():
                payload = p.payload or {}
                documents.append({
                    "id": p.id,
                    "content": payload.get("content"),
                    "metadata": payload.get("metadata"),
                })
            return documents
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all documents: {e}")
            return []

    def get_all_files(self):
        """Return distinct ``(source, username, course)`` triples found in point metadata.

        Points missing any of the three (or with falsy values) are ignored.
        Order of the returned list is unspecified. Returns ``[]`` on failure.
        """
        try:
            triples = set()
            for p in self._scroll_all():
                metadata = self._metadata_of(p)
                triple = (metadata.get("source"), metadata.get("username"), metadata.get("course"))
                if all(triple):
                    triples.add(triple)
            return list(triples)
        except Exception as e:
            print(f"[ERROR] Failed to retrieve all files: {e}")
            return []

    # --------------------------------------------------------------- management

    def remove_collection(self):
        """Delete this store's collection; failures are printed, not raised."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"[INFO] Collection '{self.collection_name}' deleted.")
        except Exception as e:
            print(f"[ERROR] Failed to delete collection: {e}")

    def list_collections(self):
        """Return the names of all collections on the server, or ``[]`` on failure."""
        try:
            return [c.name for c in self.client.get_collections().collections]
        except Exception as e:
            print(f"[ERROR] Failed to list collections: {e}")
            return []

    def remove_points_by_file(self, source_file: str, username: str, course: str):
        """Delete every point whose metadata matches this source/username/course.

        All matching IDs are collected across scroll pages before deleting, so
        files with more chunks than one page are removed completely.

        Returns:
            True if points were found and deleted, False if none matched or an
            error occurred (error printed).
        """
        try:
            file_filter = self.build_qdrant_filter([
                {"field": "source", "op": "eq", "value": source_file, "clause": "must"},
                {"field": "username", "op": "eq", "value": username, "clause": "must"},
                {"field": "course", "op": "eq", "value": course, "clause": "must"},
            ])
            point_ids = [p.id for p in self._scroll_all(scroll_filter=file_filter, with_payload=False)]
            print(f"[INFO] Found {len(point_ids)} points for file '{source_file}' to delete.")
            if not point_ids:
                print(f"[INFO] No points found for file '{source_file}' to delete.")
                return False
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=PointIdsList(points=point_ids),
            )
            print(f"[INFO] Deleted {len(point_ids)} points for file '{source_file}'")
            return True
        except Exception as e:
            print(f"[ERROR] Failed to delete points for file '{source_file}': {e}")
            return False

    def all_user_files_bookmarks(self, username: str):
        """Group a user's bookmark paths per source file into a nested dict.

        Shape: ``{source: {part: {chapter: [sections]}}}``. A path with only a
        part (e.g. ``["Preface"]``) yields ``{part: {"_sections": []}}``; path
        elements beyond the third are ignored; empty paths still register the
        source with ``{}``. Points without a source or a list-valued
        ``bookmark_path`` are skipped.

        Returns ``{}`` on failure (error printed).
        """
        try:
            user_filter = self.build_qdrant_filter(
                [{"field": "username", "op": "eq", "value": username, "clause": "must"}]
            )
            result: Dict[str, Dict[Any, Any]] = {}
            for p in self._scroll_all(scroll_filter=user_filter):
                metadata = self._metadata_of(p)
                source = metadata.get("source")
                path = metadata.get("bookmark_path")
                if not source or not isinstance(path, list):
                    continue
                # Building is idempotent per path, so duplicates need no pre-dedupe.
                nested = result.setdefault(source, {})
                if not path:
                    continue
                part = path[0]
                chapter = path[1] if len(path) > 1 else None
                section = path[2] if len(path) > 2 else None
                chapters = nested.setdefault(part, {})
                if chapter is None:
                    # Top-level bookmark (e.g. ["Preface"]).
                    chapters.setdefault("_sections", [])
                    continue
                sections = chapters.setdefault(chapter, [])
                if section and section not in sections:
                    sections.append(section)
            print(f"[INFO] Retrieved grouped bookmarks for user '{username}': {result}")
            return result
        except Exception as e:
            print(f"[ERROR] Failed to retrieve user files and bookmarks: {e}")
            return {}

    # ---------------------------------------------------------------- retrieval

    def retrieve_chunks_by_topic(self, username: str, course: str, topic_embeddings,
                                 refernces: Optional[List[Any]] = None, chunks_per_topic: int = 5):
        """Retrieve up to ``chunks_per_topic`` chunks for one topic embedding.

        Args:
            username: Only chunks whose metadata ``username`` equals this are searched.
            course: Course to restrict to; None means no course restriction.
            topic_embeddings: Query vector passed to ``query_qdrant``.
            refernces: (sic, name kept for callers) optional references, each a
                dict or object with ``filename`` and optional ``bookmarks``.
                Without bookmarks a reference is searched as a whole file;
                with bookmarks the budget is split across them (at least one
                chunk per bookmark) using a text match on ``bookmark_path``.
            chunks_per_topic: Maximum number of chunks returned.

        Returns:
            List of chunk dicts from ``query_qdrant``. If any reference had
            bookmarks, only chunks with a ``bookmark_path`` are kept; otherwise
            chunks with and without a bookmark path are interleaved.

        NOTE(review): results are concatenated in reference order before being
        truncated, so with several references the first ones can fill the
        whole budget; and once any reference has bookmarks, chunks without a
        bookmark path from whole-file references are dropped.
        """
        if chunks_per_topic <= 0:
            return []

        base_filter: List[Dict[str, Any]] = [
            {"field": "username", "op": "eq", "value": username, "clause": "must"},
        ]
        if course is not None:
            base_filter.append({"field": "course", "op": "eq", "value": course, "clause": "must"})

        results: List[Dict[str, Any]] = []
        bookmarked_only = False
        if refernces:
            for ref in refernces:
                ref_filter = base_filter + [
                    {"field": "source", "op": "eq", "value": self._ref_attr(ref, "filename"), "clause": "must"},
                ]
                bookmarks = self._ref_attr(ref, "bookmarks") or []
                if not bookmarks:
                    results.extend(self.query_qdrant(
                        filters=ref_filter, embedding=topic_embeddings, top_k=chunks_per_topic))
                    continue
                bookmarked_only = True
                # Floor at 1: more bookmarks than budget would otherwise ask for 0 results.
                per_bookmark = max(1, chunks_per_topic // len(bookmarks))
                for bookmark in bookmarks:
                    bookmark_filter = ref_filter + [
                        {"field": "bookmark_path", "op": "text", "value": bookmark, "clause": "must"},
                    ]
                    results.extend(self.query_qdrant(
                        filters=bookmark_filter, embedding=topic_embeddings, top_k=per_bookmark))
        else:
            results = self.query_qdrant(
                filters=base_filter, embedding=topic_embeddings, top_k=chunks_per_topic)

        def has_bookmark(chunk: Dict[str, Any]) -> bool:
            # "metadata" may be present but None, so `or {}` rather than a .get default.
            return bool((chunk.get("metadata") or {}).get("bookmark_path"))

        if bookmarked_only:
            ranked = [r for r in results if has_bookmark(r)]
        else:
            with_path = [r for r in results if has_bookmark(r)]
            without_path = [r for r in results if not has_bookmark(r)]
            ranked = []
            # Alternate bookmarked / non-bookmarked, bookmarked first.
            for i in range(max(len(with_path), len(without_path))):
                if i < len(with_path):
                    ranked.append(with_path[i])
                if i < len(without_path):
                    ranked.append(without_path[i])
        return ranked[:chunks_per_topic]

    def retrieve_for_exam(self, topics: List, username: str, course: Optional[str] = None,
                          references: Optional[List[Any]] = None, chunks_per_topic: int = 5):
        """Retrieve chunks for every topic via ``retrieve_chunks_by_topic``.

        Args:
            topics: Entries whose first item is the topic name and second item
                is its query embedding; other items are ignored.
            username: User whose chunks are searched.
            course: Course restriction; None searches across all courses.
            references: Optional references forwarded unchanged.
            chunks_per_topic: Maximum chunks per topic.

        Returns:
            Dict mapping topic name to its chunk list (a repeated topic name
            keeps the last result).
        """
        exam_chunks = {}
        for topic in topics:
            name, embedding = topic[0], topic[1]
            exam_chunks[name] = self.retrieve_chunks_by_topic(
                username=username,
                course=course,
                topic_embeddings=embedding,
                refernces=references,
                chunks_per_topic=chunks_per_topic,
            )
        return exam_chunks