easyResearchBigData / core /embedder.py
hzjanuary's picture
Update core/embedder.py
7ad64cc verified
from __future__ import annotations
import gc
import os
from typing import Callable
import torch
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.models import VectorParams, Distance
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
# Import updated configuration variables
from config import (
Config,
DEVICE,
EMBEDDING_MODEL,
EMBED_BATCH_SIZE,
QDRANT_MODE,
QDRANT_HOST,
QDRANT_PORT,
QDRANT_URL,
QDRANT_API_KEY,
QDRANT_STORAGE_PATH,
QDRANT_VECTOR_SIZE,
)
print(f"🚀 easyResearch running on device: {DEVICE.upper()}")
# Initialize the embedding model once, at import time. Every vector store built
# by this module embeds with this shared instance, so importing the module pays
# the model-loading cost up front.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    model_kwargs={"device": DEVICE},
    # Ask the encoder to L2-normalise vectors; the collections use cosine distance.
    encode_kwargs={"normalize_embeddings": True},
)
def get_qdrant_client() -> QdrantClient:
    """Create a QdrantClient for the configured ``QDRANT_MODE``.

    Modes:
        * ``"cloud"``: connect by URL with ``QDRANT_API_KEY``. ``QDRANT_URL`` is
          used when set; otherwise ``QDRANT_HOST`` is passed as the URL.
        * ``"local"``: embedded on-disk storage at ``QDRANT_STORAGE_PATH``.
        * anything else: a Qdrant server at ``QDRANT_HOST``:``QDRANT_PORT``.

    A new client instance is created on every call.
    """
    if QDRANT_MODE == "cloud":
        # QDRANT_URL is imported from config but was previously unused;
        # presumably it is the intended cloud endpoint. Keep QDRANT_HOST as a
        # fallback so configs that only set the host keep working.
        return QdrantClient(url=QDRANT_URL or QDRANT_HOST, api_key=QDRANT_API_KEY)
    if QDRANT_MODE == "local":
        return QdrantClient(path=QDRANT_STORAGE_PATH)
    return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
def check_qdrant_health() -> dict:
    """Probe the configured Qdrant backend and describe it.

    Returns:
        On success:
        ``{"status": "ok", "mode": ..., "target": ..., "collections": [names]}``.
        ``target`` is the storage path in local mode, the connection URL in
        cloud mode, and ``host:port`` otherwise.
        On failure: ``{"status": "error", "error": <message>}``. Errors derived
        from ``Exception`` are reported in the returned dict, not raised.
    """
    try:
        client = get_qdrant_client()
        info = client.get_collections()
        if QDRANT_MODE == "local":
            target = QDRANT_STORAGE_PATH
        elif QDRANT_MODE == "cloud":
            # Cloud mode connects by URL; "host:port" would misreport the endpoint.
            target = QDRANT_URL or QDRANT_HOST
        else:
            target = f"{QDRANT_HOST}:{QDRANT_PORT}"
        return {
            "status": "ok",
            "mode": QDRANT_MODE,
            "target": target,
            "collections": [c.name for c in info.collections],
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}
def ensure_collection_exists(collection_name: str) -> None:
    """Create *collection_name* (cosine distance) if it does not exist yet.

    Vectors are sized by ``QDRANT_VECTOR_SIZE``, which has to match the
    dimension produced by the embedding model.

    Raises:
        Whatever the Qdrant client raises if the backend cannot be reached or
        the create request fails.
    """
    client = get_qdrant_client()
    # Ask the existence question directly instead of treating *any* error from
    # get_collection() as "missing". Otherwise network/auth/timeout failures
    # are swallowed and followed by a create_collection() call that cannot
    # succeed, or that fails with a misleading "already exists".
    if client.collection_exists(collection_name):
        return
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=QDRANT_VECTOR_SIZE, distance=Distance.COSINE),
    )
    print(f"✨ Created collection: {collection_name}")
def _initialize_vector_store(col_name: str) -> QdrantVectorStore:
    """Return a QdrantVectorStore bound to collection *col_name*.

    The collection is created first if it is missing. The store talks to
    Qdrant through ``get_qdrant_client()``, so the connection settings (mode,
    URL/host/port, API key, local storage path) are defined in exactly one
    place. Documents are embedded with the module-level ``embedding_model``.
    """
    ensure_collection_exists(col_name)
    # Equivalent to the per-mode from_existing_collection(...) calls: those
    # build a client from the same settings and pass it to this constructor,
    # which still validates the collection's vector configuration.
    return QdrantVectorStore(
        client=get_qdrant_client(),
        collection_name=col_name,
        embedding=embedding_model,
    )
def get_vector_store(collection_name: str) -> QdrantVectorStore:
    """Open the vector store that backs the workspace *collection_name*.

    The workspace name is translated to its Qdrant collection name with
    ``Config.get_collection_name``. The collection is created on first use.
    """
    return _initialize_vector_store(Config.get_collection_name(collection_name))
def _point_id_for(chunk):
    """Return *chunk*'s id, or a fresh random UUID hex when it has none.

    LangChain documents may carry ``id=None`` (or no ``id`` attribute at all).
    Qdrant rejects ``None``/empty point ids, so those get a random UUID hex,
    the same kind of id langchain_qdrant generates when called without ids.
    Falsy-but-valid ids such as the integer ``0`` are kept unchanged.
    """
    import uuid

    chunk_id = getattr(chunk, "id", None)
    if chunk_id is None or chunk_id == "":
        return uuid.uuid4().hex
    return chunk_id


def add_to_vector_db(
    chunks,
    collection_name: str,
    batch_size: int | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
):
    """Embed document chunks and upsert them into a workspace collection.

    Args:
        chunks: Iterable of LangChain-style documents exposing ``page_content``,
            ``metadata`` and optionally ``id``. It is materialised into a list
            once, so generators are accepted.
        collection_name: Workspace name. It is mapped to a Qdrant collection via
            ``Config.get_collection_name``, and the collection is created if
            missing.
        batch_size: Chunks embedded per ``add_texts`` call. ``None`` or ``0``
            falls back to ``EMBED_BATCH_SIZE``.
        progress_callback: Optional ``callback(done, total)`` invoked after
            every batch.

    Returns:
        The ``QdrantVectorStore`` the chunks were written to.

    Raises:
        ValueError: If the effective batch size is smaller than 1.
    """
    bs = batch_size or EMBED_BATCH_SIZE
    if bs < 1:
        # range() with a negative step yields nothing, which would silently
        # skip every chunk; reject it before touching the database.
        raise ValueError(f"batch_size must be >= 1, got {bs!r}")

    chunks = list(chunks)
    col_name = Config.get_collection_name(collection_name)
    db = _initialize_vector_store(col_name)

    texts = [chunk.page_content for chunk in chunks]
    metadatas = [chunk.metadata for chunk in chunks]
    ids = [_point_id_for(chunk) for chunk in chunks]
    total = len(chunks)

    print(f"📥 Embedding {total} chunks into '{col_name}' (mode={QDRANT_MODE}, batch={bs}) …")
    for start in range(0, total, bs):
        end = min(start + bs, total)
        db.add_texts(
            texts=texts[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
        # Free cached GPU memory and collect garbage between batches to keep
        # peak memory down on large uploads.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print(f" ✅ Batch {start} → {end}")
        if progress_callback:
            progress_callback(end, total)
    return db
def get_retriever(collection_name: str, k: int = 5, fetch_k: int = 20, lambda_mult: float = 0.5):
    """Get a retriever over a workspace using MMR (Maximal Marginal Relevance).

    Args:
        collection_name: Workspace name.
        k: Number of documents returned per query.
        fetch_k: Number of candidates fetched by similarity before MMR
            re-ranking. It is raised to ``k`` when smaller, because MMR can only
            choose among the fetched candidates and would otherwise return
            fewer than ``k`` documents.
        lambda_mult: Relevance/diversity trade-off in [0, 1]. 1 favours pure
            relevance, 0 favours maximum diversity. 0.5 matches LangChain's
            default.
    """
    db = get_vector_store(collection_name)
    return db.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            "fetch_k": max(fetch_k, k),
            "lambda_mult": lambda_mult,
        },
    )
def _list_sources(client: QdrantClient, col_name: str, page_size: int = 1000) -> list[str]:
    """Return the sorted, distinct, non-empty ``metadata.source`` values in *col_name*."""
    sources = set()
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=col_name,
            limit=page_size,
            offset=offset,
            with_payload=True,
            with_vectors=False,
        )
        for point in points:
            # Tolerate points with no payload or a null/non-dict "metadata".
            metadata = (point.payload or {}).get("metadata") or {}
            source = metadata.get("source") if isinstance(metadata, dict) else None
            if source:
                sources.add(source)
        # scroll() returns None as the next offset once the last page is read;
        # a single page would miss files in large collections.
        if offset is None:
            return sorted(sources)


def get_notebook_stats(notebook_name: str) -> dict:
    """Fetch statistics for a workspace collection.

    Returns:
        ``{"chunks": int, "files": list[str], "size_mb": float}``. ``chunks`` is
        the collection's point count. ``files`` holds the distinct
        ``metadata.source`` values across all points. ``size_mb`` is not
        computed here and stays ``0.0``.

    This is best effort. A missing collection yields the zeroed defaults, and
    backend errors are printed, with whatever was gathered so far returned.
    """
    col_name = Config.get_collection_name(notebook_name)
    stats = {"chunks": 0, "files": [], "size_mb": 0.0}
    try:
        client = get_qdrant_client()
        # A brand-new workspace has no collection yet; that is not an error.
        if not client.collection_exists(col_name):
            return stats
        info = client.get_collection(col_name)
        stats["chunks"] = info.points_count or 0
        if stats["chunks"] > 0:
            stats["files"] = _list_sources(client, col_name)
    except Exception as e:
        print(f"⚠️ Notebook stats error for '{col_name}': {e}")
    return stats
def get_total_db_size() -> float:
    """Return the combined on-disk size of all collections in MiB, rounded to 2 decimals.

    Each collection contributes the ``disk_data_size`` attribute of its info
    object when that attribute is present and truthy, otherwise 0. A
    collection whose info cannot be fetched also contributes 0. If the client
    cannot be created or the collections cannot be listed, 0.0 is returned.

    NOTE(review): qdrant-client's ``CollectionInfo`` may not define
    ``disk_data_size`` at all, in which case this always yields 0.0. Verify
    against the installed client version.
    """

    def _bytes_of(client, name: str):
        # Per-collection failures are deliberately ignored (count as 0 bytes).
        try:
            return getattr(client.get_collection(name), 'disk_data_size', 0) or 0
        except Exception:
            return 0

    try:
        client = get_qdrant_client()
        names = [col.name for col in client.get_collections().collections]
        size_bytes = sum(_bytes_of(client, name) for name in names)
        return round(size_bytes / (1024 * 1024), 2)
    except Exception:
        return 0.0
def get_all_notebooks() -> list[str]:
    """List workspace names, derived from collections named ``ws_<name>``.

    The ``ws_`` prefix is stripped from each matching collection name, and
    collections without the prefix are ignored. On any error the error is
    printed and an empty list is returned.
    """
    prefix = "ws_"
    try:
        listing = get_qdrant_client().get_collections()
        return [
            col.name[len(prefix):]
            for col in listing.collections
            if col.name.startswith(prefix)
        ]
    except Exception as err:
        e = err
        print(f"⚠️ List collections error: {e}")
        return []
def delete_file_from_notebook(notebook_name: str, source_name: str) -> int:
    """Delete all chunks belonging to one source file within a workspace.

    Chunks are matched on the ``metadata.source`` payload field. The delete is
    filter-based, so every matching point is removed regardless of how many
    chunks the file was split into.

    Returns:
        The number of matching chunks (counted just before deleting). Returns 0
        when nothing matched or an error occurred; errors are printed, not
        raised.
    """
    col_name = Config.get_collection_name(notebook_name)
    try:
        from qdrant_client.models import FieldCondition, Filter, FilterSelector, MatchValue

        source_filter = Filter(
            must=[
                FieldCondition(
                    key="metadata.source",
                    match=MatchValue(value=source_name),
                )
            ]
        )
        client = get_qdrant_client()
        # Count first so the return value keeps its meaning. The old approach
        # (scroll one page of up to 10000 ids, delete those) left the rest of
        # a large file behind.
        n_matches = client.count(
            collection_name=col_name,
            count_filter=source_filter,
            exact=True,
        ).count
        if not n_matches:
            return 0
        client.delete(
            collection_name=col_name,
            points_selector=FilterSelector(filter=source_filter),
        )
        print(f"🗑️ Deleted {n_matches} chunks of '{source_name}' from {col_name}")
        return n_matches
    except Exception as e:
        print(f"❌ Delete file error: {e}")
        return 0
def delete_notebook(notebook_name: str) -> bool:
    """Delete the entire Qdrant collection that backs a workspace.

    Returns:
        True whenever ``delete_collection`` does not raise (its boolean result
        is not inspected). False if an error occurred; the error is printed.
    """
    col_name = Config.get_collection_name(notebook_name)
    try:
        client = get_qdrant_client()
        client.delete_collection(col_name)
        print(f"🗑️ Deleted collection: {col_name}")
        return True
    except Exception as e:
        print(f"❌ Delete collection error: {e}")
        return False