from __future__ import annotations import gc import os from typing import Callable import torch from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.models import VectorParams, Distance from langchain_qdrant import QdrantVectorStore from langchain_huggingface import HuggingFaceEmbeddings # Import updated configuration variables from config import ( Config, DEVICE, EMBEDDING_MODEL, EMBED_BATCH_SIZE, QDRANT_MODE, QDRANT_HOST, QDRANT_PORT, QDRANT_URL, QDRANT_API_KEY, QDRANT_STORAGE_PATH, QDRANT_VECTOR_SIZE, ) print(f"🚀 easyResearch running on device: {DEVICE.upper()}") # Initialize Embedding Model embedding_model = HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL, model_kwargs={"device": DEVICE}, encode_kwargs={"normalize_embeddings": True}, ) def get_qdrant_client() -> QdrantClient: """Initialize QdrantClient based on the configured mode.""" if QDRANT_MODE == "cloud": return QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API_KEY) elif QDRANT_MODE == "local": return QdrantClient(path=QDRANT_STORAGE_PATH) else: return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) def check_qdrant_health() -> dict: """Check connection health and return basic info.""" try: client = get_qdrant_client() info = client.get_collections() target = QDRANT_STORAGE_PATH if QDRANT_MODE == "local" else f"{QDRANT_HOST}:{QDRANT_PORT}" return { "status": "ok", "mode": QDRANT_MODE, "target": target, "collections": [c.name for c in info.collections], } except Exception as e: return {"status": "error", "error": str(e)} def ensure_collection_exists(collection_name: str) -> None: """Create collection if it does not exist.""" client = get_qdrant_client() try: client.get_collection(collection_name) except (UnexpectedResponse, Exception): client.create_collection( collection_name=collection_name, vectors_config=VectorParams(size=QDRANT_VECTOR_SIZE, distance=Distance.COSINE), ) print(f"✨ Created collection: {collection_name}") def _initialize_vector_store(col_name: str) -> QdrantVectorStore: """Internal helper to create QdrantVectorStore instance based on mode.""" ensure_collection_exists(col_name) if QDRANT_MODE == "local": return QdrantVectorStore.from_existing_collection( embedding=embedding_model, collection_name=col_name, path=QDRANT_STORAGE_PATH, ) elif QDRANT_MODE == "cloud": return QdrantVectorStore.from_existing_collection( embedding=embedding_model, collection_name=col_name, url=QDRANT_HOST, api_key=QDRANT_API_KEY, ) else: return QdrantVectorStore.from_existing_collection( embedding=embedding_model, collection_name=col_name, url=f"http://{QDRANT_HOST}:{QDRANT_PORT}", ) def get_vector_store(collection_name: str) -> QdrantVectorStore: """Public method to get vector store for a specific workspace.""" col_name = Config.get_collection_name(collection_name) return _initialize_vector_store(col_name) def add_to_vector_db( chunks, collection_name: str, batch_size: int | None = None, progress_callback: Callable[[int, int], None] | None = None, ): """Embed and add document chunks to the vector database.""" col_name = Config.get_collection_name(collection_name) db = _initialize_vector_store(col_name) texts = [chunk.page_content for chunk in chunks] metadatas = [chunk.metadata for chunk in chunks] ids = [chunk.id for chunk in chunks] bs = batch_size or EMBED_BATCH_SIZE total = len(chunks) print(f"📥 Embedding {total} chunks into '{col_name}' (mode={QDRANT_MODE}, batch={bs}) …") for i in range(0, total, bs): end = min(i + bs, total) db.add_texts( texts=texts[i:end], metadatas=metadatas[i:end], ids=ids[i:end], ) # Memory management if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() print(f" ✅ Batch {i} → {end}") if progress_callback: progress_callback(end, total) return db def get_retriever(collection_name: str, k: int = 5, fetch_k: int = 20): """Get a retriever object using MMR (Max Marginal Relevance).""" db = get_vector_store(collection_name) return db.as_retriever( search_type="mmr", search_kwargs={"k": k, "fetch_k": fetch_k}, ) def get_notebook_stats(notebook_name: str) -> dict: """Fetch statistics for a specific workspace collection.""" col_name = Config.get_collection_name(notebook_name) stats = {"chunks": 0, "files": [], "size_mb": 0.0} try: client = get_qdrant_client() info = client.get_collection(col_name) stats["chunks"] = info.points_count or 0 if stats["chunks"] > 0: points, _ = client.scroll( collection_name=col_name, limit=10000, with_payload=True, with_vectors=False, ) sources = { p.payload.get("metadata", {}).get("source", "") for p in points if p.payload } stats["files"] = sorted([s for s in sources if s]) if hasattr(info, "payload_schema"): # Check if size info is available # disk_data_size is not always available in all client versions/modes pass except (UnexpectedResponse, Exception): pass return stats def get_total_db_size() -> float: """Calculate total size of all collections (if supported by mode).""" try: client = get_qdrant_client() collections = client.get_collections() total = 0 for col in collections.collections: try: info = client.get_collection(col.name) total += getattr(info, 'disk_data_size', 0) or 0 except Exception: pass return round(total / (1024 * 1024), 2) except Exception: return 0.0 def get_all_notebooks() -> list[str]: """List all workspace names by scanning collection prefixes.""" try: client = get_qdrant_client() collections = client.get_collections() workspaces = [] for c in collections.collections: if c.name.startswith("ws_"): ws_name = c.name[3:] workspaces.append(ws_name) return workspaces except Exception as e: print(f"⚠️ List collections error: {e}") return [] def delete_file_from_notebook(notebook_name: str, source_name: str) -> int: """Delete all chunks belonging to a specific file within a workspace.""" col_name = Config.get_collection_name(notebook_name) try: client = get_qdrant_client() from qdrant_client.models import Filter, FieldCondition, MatchValue points, _ = client.scroll( collection_name=col_name, scroll_filter=Filter( must=[ FieldCondition( key="metadata.source", match=MatchValue(value=source_name), ) ] ), limit=10000, with_payload=False, with_vectors=False, ) if points: ids = [p.id for p in points] client.delete( collection_name=col_name, points_selector=ids, ) print(f"🗑️ Deleted {len(ids)} chunks of '{source_name}' from {col_name}") return len(ids) return 0 except Exception as e: print(f"❌ Delete file error: {e}") return 0 def delete_notebook(notebook_name: str) -> bool: """Delete an entire workspace collection.""" col_name = Config.get_collection_name(notebook_name) try: client = get_qdrant_client() client.delete_collection(col_name) print(f"🗑️ Deleted collection: {col_name}") return True except Exception as e: print(f"❌ Delete collection error: {e}") return False