Spaces:
Runtime error
Runtime error
| import os | |
| import faiss | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| import logging | |
| import json | |
| from functools import lru_cache | |
| from sentence_transformers import SentenceTransformer | |
| from .embeddings_utils import load_faiss_index, load_metadata | |
| from .graph_index import HierarchicalGraphManager | |
| # Import your new graph manager | |
| # 1. Global Cache for the Embedding Model | |
| # This prevents the 2-5 second delay of reloading the model on every question. | |
| _MODEL_CACHE = {} | |
| def get_embedding_model(model_name: str): | |
| if model_name not in _MODEL_CACHE: | |
| # Use 'auto' to ensure it fits on whatever hardware is available | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| _MODEL_CACHE[model_name] = SentenceTransformer(model_name, device=device, trust_remote_code=True) | |
| return _MODEL_CACHE[model_name] | |
def query_index(index_path: str, meta_path: str, query: str, top_k: int = 10, model_name: str = "nomic-ai/nomic-embed-text-v1"):
    """Embed *query* and return the ``top_k`` nearest chunks from a FAISS index.

    Args:
        index_path: path to the serialized FAISS index file.
        meta_path: path to the chunk-metadata file that parallels the index.
        query: natural-language query string.
        top_k: maximum number of hits to return.
        model_name: SentenceTransformer model used for the query embedding.

    Returns:
        A list of dicts with ``score``, ``text``, ``source``, ``page`` and
        ``chunk_id`` keys, ordered as FAISS returned them.
    """
    index = load_faiss_index(index_path)
    meta = load_metadata(meta_path)

    # Cached ("warm") model — avoids a multi-second reload per question.
    model = get_embedding_model(model_name)

    # Inference without autograd bookkeeping; FAISS wants float32.
    with torch.no_grad():
        q_emb = model.encode([query], convert_to_numpy=True).astype('float32')

    # Unit-normalize so inner-product search behaves as cosine similarity,
    # matching how the index itself was built.
    faiss.normalize_L2(q_emb)

    distances, indices = index.search(q_emb, top_k)

    hits = []
    for score, pos in zip(distances[0], indices[0]):
        # FAISS pads missing neighbors with -1; also guard against any
        # index/metadata mismatch.
        if not (0 <= pos < len(meta)):
            continue
        record = meta[pos]
        hits.append({
            "score": float(score),
            "text": record.get("text"),
            "source": record.get("source"),
            "page": record.get("page"),
            "chunk_id": record.get("chunk_id"),
        })
    return hits
def get_latest_session_paths():
    """Return (index, meta, graph) paths for the freshest usable session.

    Scans ``data/_staging`` and ``data/uploads`` for session sub-directories
    and picks the most recently modified one.

    Fix: a brand-new or half-written session folder (no ``faiss.index`` yet)
    previously shadowed an older, fully indexed session, making retrieval
    return nothing. Sessions that already contain a ``faiss.index`` are now
    preferred; only if none exist do we fall back to the raw latest folder,
    and finally to the flat ``data/`` layout.

    Returns:
        tuple[Path, Path, Path]: paths to ``faiss.index``, ``meta.pkl`` and
        ``graph_data.pkl`` (they may not exist — callers must check).
    """
    base_dirs = [Path("data/_staging"), Path("data/uploads")]

    all_sessions = []
    for base in base_dirs:
        if base.exists():
            all_sessions.extend(d for d in base.iterdir() if d.is_dir())

    # Prefer sessions that actually hold a built vector index.
    indexed = [d for d in all_sessions if (d / "faiss.index").exists()]
    candidates = indexed or all_sessions

    if not candidates:
        # Fallback to the root data directory if no session folders exist.
        return Path("data/faiss.index"), Path("data/meta.pkl"), Path("data/graph_data.pkl")

    latest_dir = max(candidates, key=lambda d: d.stat().st_mtime)
    # Lazy %-formatting: the string is only built when INFO is enabled.
    logging.info("Context retriever: pointing to latest session -> %s", latest_dir.name)
    return (
        latest_dir / "faiss.index",
        latest_dir / "meta.pkl",
        latest_dir / "graph_data.pkl",
    )
def retrieve_contexts(query: str, idx_path, meta_path, k: int = 10, return_with_sources: bool = True):
    """Run a vector search against the given FAISS index, best-effort.

    Args:
        query: natural-language query to embed and search with.
        idx_path: path to the ``faiss.index`` file (``str`` or ``Path``).
        meta_path: path to the metadata file (``str`` or ``Path``).
        k: number of nearest chunks to request.
        return_with_sources: if True return full result dicts, otherwise
            just the ``text`` field of each hit.

    Returns:
        A list of result dicts (or text strings), or ``[]`` when the index
        files are missing or the search raises.
    """
    # Generalization fix: callers passing plain string paths used to crash
    # with AttributeError on ``.exists()``; coerce to Path up front.
    idx_path, meta_path = Path(idx_path), Path(meta_path)

    if not idx_path.exists() or not meta_path.exists():
        logging.warning("No vector index found in the latest session.")
        return []

    try:
        results = query_index(str(idx_path), str(meta_path), query, top_k=k)
    except Exception as e:
        # Deliberate best-effort: a broken index must not crash the caller.
        logging.error("Vector Retrieval Error: %s", e)
        return []
    return results if return_with_sources else [r["text"] for r in results]
def retrieve_hybrid_context(query: str, k: int = 5):
    """Master coordinator: merge vector search and graph summaries to JSON.

    Args:
        query: the user question.
        k: number of vector chunks to retrieve.

    Returns:
        str: a JSON document with ``vector_context``, ``graph_context`` and
        ``metadata`` keys. Never raises for missing/broken artifacts — each
        retrieval leg degrades to an empty list.
    """
    # 1. Locate the most recent session's artifacts.
    idx_path, meta_path, graph_path = get_latest_session_paths()
    logging.info("Session target: %s", idx_path.parent.name)

    # 2. Vector context (retrieve_contexts already returns [] on failure).
    vector_results = retrieve_contexts(query, idx_path, meta_path, k=k, return_with_sources=True)
    logging.info("Vector search: found %d chunks", len(vector_results))

    # 3. Graph context. Fix: this leg was unguarded, so a missing or corrupt
    #    graph_data.pkl crashed the whole hybrid retrieval while the vector
    #    leg was protected. Mirror the vector leg's best-effort handling.
    try:
        gm = HierarchicalGraphManager(storage_path=str(graph_path))
        graph_summaries = gm.get_relevant_community_summaries(query)
    except Exception as e:
        logging.error("Graph Retrieval Error: %s", e)
        graph_summaries = []
    logging.info("Graph search: found %d community summaries", len(graph_summaries))

    # 4. Prepare for serialization.
    # NOTE(review): "metadata" deliberately mirrors "vector_context" here —
    # presumably a downstream consumer reads it under that key; confirm
    # before deduplicating.
    data = {
        "vector_context": vector_results,
        "graph_context": graph_summaries,
        "metadata": vector_results,
    }

    # 5. Log the final state before string conversion.
    if not data["vector_context"] and not data["graph_context"]:
        logging.warning("Hybrid retrieval returned ZERO context.")
    else:
        logging.info("Hybrid context successfully compiled.")

    # default=str stringifies any non-JSON-serializable metadata values
    # instead of raising.
    return json.dumps(data, indent=4, ensure_ascii=False, default=str)