Spaces:
Sleeping
Sleeping
Ara Yeroyan
committed on
Commit
·
f5df983
1
Parent(s):
26449fc
add src
Browse files- src/__init__.py +10 -0
- src/config/__init__.py +5 -0
- src/config/collections.json +22 -0
- src/config/loader.py +170 -0
- src/config/settings.yaml +92 -0
- src/llm/__init__.py +6 -0
- src/llm/adapters.py +409 -0
- src/llm/templates.py +232 -0
- src/loader.py +115 -0
- src/logging.py +193 -0
- src/pipeline.py +731 -0
- src/reporting/__init__.py +6 -0
- src/reporting/feedback_schema.py +196 -0
- src/reporting/metadata.py +216 -0
- src/reporting/service.py +144 -0
- src/reporting/snowflake_connector.py +305 -0
- src/retrieval/__init__.py +15 -0
- src/retrieval/colbert_cache.py +74 -0
- src/retrieval/context.py +881 -0
- src/retrieval/filter.py +975 -0
- src/retrieval/hybrid.py +479 -0
- src/vectorstore.py +266 -0
src/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audit QA Refactored Module
|
| 3 |
+
A modular and maintainable RAG pipeline for audit report analysis.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .pipeline import PipelineManager
|
| 7 |
+
from .config.loader import load_config
|
| 8 |
+
|
| 9 |
+
__version__ = "2.0.0"
|
| 10 |
+
__all__ = ["PipelineManager", "load_config"]
|
src/config/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration management for Audit QA."""
|
| 2 |
+
|
| 3 |
+
from .loader import load_config, get_nested_config
|
| 4 |
+
|
| 5 |
+
__all__ = ["load_config", "get_nested_config"]
|
src/config/collections.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"docling": {
|
| 3 |
+
"model": "BAAI/bge-m3",
|
| 4 |
+
"description": "Default collection with BGE-M3 embedding model"
|
| 5 |
+
},
|
| 6 |
+
"modernbert-embed-base-akryl-matryoshka": {
|
| 7 |
+
"model": "Akryl/modernbert-embed-base-akryl-matryoshka",
|
| 8 |
+
"description": "ModernBERT embedding model with matryoshka representation"
|
| 9 |
+
},
|
| 10 |
+
"sentence-transformers-all-MiniLM-L6-v2": {
|
| 11 |
+
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
| 12 |
+
"description": "Sentence transformers MiniLM model"
|
| 13 |
+
},
|
| 14 |
+
"sentence-transformers-all-mpnet-base-v2": {
|
| 15 |
+
"model": "sentence-transformers/all-mpnet-base-v2",
|
| 16 |
+
"description": "Sentence transformers MPNet model"
|
| 17 |
+
},
|
| 18 |
+
"BAAI-bge-m3": {
|
| 19 |
+
"model": "BAAI/bge-m3",
|
| 20 |
+
"description": "BAAI BGE-M3 multilingual embedding model"
|
| 21 |
+
}
|
| 22 |
+
}
|
src/config/loader.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration loader for YAML settings."""
|
| 2 |
+
|
| 3 |
+
import yaml
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
import re

# Matches ${VAR_NAME} placeholders in the raw YAML text. Compiled once at
# module import instead of on every load_config() call.
_ENV_PLACEHOLDER = re.compile(r'\$\{([^}]+)\}')


def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load configuration from a YAML file.

    ``${VAR}`` placeholders in the file are substituted with environment
    variables, then selected keys are overridden by dedicated environment
    variables (see :func:`_override_with_env_vars`).

    Args:
        config_path: Path to config file. If None, uses the default
            settings.yaml located next to this module.

    Returns:
        Dictionary containing configuration settings.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    if config_path is None:
        # Default to settings.yaml in the same directory as this file.
        config_path = Path(__file__).parent / "settings.yaml"

    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        content = f.read()

    def replace_env_vars(match):
        # Return the original ${VAR} text when the variable is not set so a
        # missing value stays visible downstream instead of becoming empty.
        return os.getenv(match.group(1), match.group(0))

    # Replace ${VAR} patterns with environment variables.
    content = _ENV_PLACEHOLDER.sub(replace_env_vars, content)

    config = yaml.safe_load(content)

    # Apply targeted overrides from dedicated environment variables.
    return _override_with_env_vars(config)


def _override_with_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
    """Override config values with environment variables where available.

    Mutates and returns *config*. Integer, boolean and float keys are
    coerced from their string environment values.
    """

    # Map environment variables to nested config key paths.
    env_mappings = {
        'QDRANT_URL': ['qdrant', 'url'],
        'QDRANT_COLLECTION': ['qdrant', 'collection_name'],
        'QDRANT_API_KEY': ['qdrant', 'api_key'],
        'RETRIEVER_MODEL': ['retriever', 'model'],
        'RANKER_MODEL': ['ranker', 'model'],
        'READER_TYPE': ['reader', 'default_type'],
        'MAX_TOKENS': ['reader', 'max_tokens'],
        'MISTRAL_API_KEY': ['reader', 'MISTRAL', 'api_key'],
        'OPENAI_API_KEY': ['reader', 'OPENAI', 'api_key'],
        'NEBIUS_API_KEY': ['reader', 'INF_PROVIDERS', 'api_key'],
        'NVIDIA_SERVER_API_KEY': ['reader', 'NVIDIA', 'api_key'],
        'SERVERLESS_API_KEY': ['reader', 'SERVERLESS', 'api_key'],
        'DEDICATED_API_KEY': ['reader', 'DEDICATED', 'api_key'],
        'OPENROUTER_API_KEY': ['reader', 'OPENROUTER', 'api_key'],
    }

    for env_var, config_path in env_mappings.items():
        env_value = os.getenv(env_var)
        if env_value:
            # Navigate to the nested parent mapping, creating levels as needed.
            current = config
            for key in config_path[:-1]:
                current = current.setdefault(key, {})

            # Set the final value, converting to the appropriate type.
            final_key = config_path[-1]
            if final_key in ('top_k', 'max_tokens', 'num_predict'):
                current[final_key] = int(env_value)
            elif final_key in ('normalize', 'prefer_grpc'):
                current[final_key] = env_value.lower() in ('true', '1', 'yes')
            elif final_key == 'temperature':
                current[final_key] = float(env_value)
            else:
                current[final_key] = env_value

    return config
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_nested_config(config: Dict[str, Any], path: str, default=None):
    """
    Look up a nested configuration value addressed by dot notation.

    Args:
        config: Configuration dictionary
        path: Dot-separated path (e.g., 'reader.MISTRAL.model')
        default: Value returned when the path cannot be resolved

    Returns:
        The value at *path*, or *default* if any segment is missing or a
        non-subscriptable value is reached along the way.
    """
    node = config
    for segment in path.split('.'):
        try:
            node = node[segment]
        except (KeyError, TypeError):
            return default
    return node
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Substring markers (checked in this order) used to infer an embedding model
# from a collection name that is not listed in collections.json.
_MODEL_NAME_PATTERNS = [
    ("modernbert", "Akryl/modernbert-embed-base-akryl-matryoshka"),
    ("minilm", "sentence-transformers/all-MiniLM-L6-v2"),
    ("mpnet", "sentence-transformers/all-mpnet-base-v2"),
    ("bge", "BAAI/bge-m3"),
]


def load_collections_mapping() -> Dict[str, Dict[str, str]]:
    """Load the collection-name -> embedding-model mapping from collections.json.

    Falls back to a minimal built-in default mapping when the JSON file
    next to this module is absent.
    """
    collections_file = Path(__file__).parent / "collections.json"

    if not collections_file.exists():
        # Return default mapping if file doesn't exist.
        return {
            "docling": {
                "model": "sentence-transformers/all-MiniLM-L6-v2",
                "description": "Default collection"
            }
        }

    with open(collections_file, 'r') as f:
        return json.load(f)


def get_embedding_model_for_collection(collection_name: str) -> Optional[str]:
    """Get the embedding model for a specific collection name.

    Known collections come from the mapping file; otherwise the model is
    inferred from well-known substrings in the (lower-cased) name.

    Returns:
        The model identifier, or None when nothing can be inferred.
    """
    collections = load_collections_mapping()

    if collection_name in collections:
        return collections[collection_name]["model"]

    # Try to infer from collection name patterns; lower-case exactly once.
    lowered = collection_name.lower()
    for marker, model in _MODEL_NAME_PATTERNS:
        if marker in lowered:
            return model

    return None


def get_collection_info(collection_name: str) -> Dict[str, str]:
    """Get full collection information including model and description.

    Unknown collections get an auto-inferred entry whose model falls back
    to the literal string "unknown" when inference fails.
    """
    collections = load_collections_mapping()

    if collection_name in collections:
        return collections[collection_name]

    # Return inferred info for unknown collections.
    model = get_embedding_model_for_collection(collection_name)
    return {
        "model": model or "unknown",
        "description": f"Auto-inferred collection: {collection_name}"
    }
|
src/config/settings.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audit QA Configuration
|
| 2 |
+
# Converted from model_params.cfg to YAML format
|
| 3 |
+
|
| 4 |
+
qdrant:
|
| 5 |
+
# url: "http://10.1.4.192:8803"
|
| 6 |
+
url: "https://2c6d0136-b6ca-4400-bac5-1703f58abc43.europe-west3-0.gcp.cloud.qdrant.io"
|
| 7 |
+
collection_name: "docling"
|
| 8 |
+
prefer_grpc: true
|
| 9 |
+
api_key: "${QDRANT_API_KEY}" # Load from environment variable
|
| 10 |
+
|
| 11 |
+
retriever:
|
| 12 |
+
model: "BAAI/bge-m3"
|
| 13 |
+
normalize: true
|
| 14 |
+
top_k: 20
|
| 15 |
+
|
| 16 |
+
retrieval:
|
| 17 |
+
use_reranking: true
|
| 18 |
+
reranker_model: "BAAI/bge-reranker-v2-m3"
|
| 19 |
+
reranker_top_k: 5
|
| 20 |
+
|
| 21 |
+
ranker:
|
| 22 |
+
model: "BAAI/bge-reranker-v2-m3"
|
| 23 |
+
top_k: 5
|
| 24 |
+
|
| 25 |
+
bm25:
|
| 26 |
+
top_k: 20
|
| 27 |
+
|
| 28 |
+
hybrid:
|
| 29 |
+
default_mode: "vector_only" # Options: vector_only, sparse_only, hybrid
|
| 30 |
+
default_alpha: 0.5 # Weight for vector scores (0.5 = equal weight)
|
| 31 |
+
|
| 32 |
+
reader:
|
| 33 |
+
default_type: "OPENAI"
|
| 34 |
+
max_tokens: 768
|
| 35 |
+
|
| 36 |
+
# Different LLM provider configurations
|
| 37 |
+
INF_PROVIDERS:
|
| 38 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 39 |
+
provider: "nebius"
|
| 40 |
+
|
| 41 |
+
# Not working
|
| 42 |
+
NVIDIA:
|
| 43 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 44 |
+
endpoint: "https://huggingface.co/api/integrations/dgx/v1"
|
| 45 |
+
|
| 46 |
+
# Not working
|
| 47 |
+
DEDICATED:
|
| 48 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 49 |
+
endpoint: "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
|
| 50 |
+
|
| 51 |
+
MISTRAL:
|
| 52 |
+
model: "mistral-medium-latest"
|
| 53 |
+
|
| 54 |
+
OPENAI:
|
| 55 |
+
model: "gpt-4o-mini"
|
| 56 |
+
|
| 57 |
+
OLLAMA:
|
| 58 |
+
model: "mistral-small3.1:24b-instruct-2503-q8_0"
|
| 59 |
+
base_url: "http://10.1.4.192:11434/"
|
| 60 |
+
temperature: 0.8
|
| 61 |
+
num_predict: 256
|
| 62 |
+
|
| 63 |
+
OPENROUTER:
|
| 64 |
+
model: "moonshotai/kimi-k2:free"
|
| 65 |
+
base_url: "https://openrouter.ai/api/v1"
|
| 66 |
+
temperature: 0.7
|
| 67 |
+
max_tokens: 1000
|
| 68 |
+
# site_url: "https://your-site.com" # optional, for OpenRouter ranking
|
| 69 |
+
# site_name: "Your Site Name" # optional, for OpenRouter ranking
|
| 70 |
+
|
| 71 |
+
app:
|
| 72 |
+
dropdown_default: "Annual Consolidated OAG 2024"
|
| 73 |
+
|
| 74 |
+
# File paths
|
| 75 |
+
paths:
|
| 76 |
+
chunks_file: "reports/docling_chunks.json"
|
| 77 |
+
reports_dir: "reports"
|
| 78 |
+
|
| 79 |
+
# Feature toggles
|
| 80 |
+
features:
|
| 81 |
+
enable_session: true
|
| 82 |
+
enable_logging: true
|
| 83 |
+
|
| 84 |
+
# Logging and HuggingFace scheduler configuration
|
| 85 |
+
logging:
|
| 86 |
+
json_dataset_dir: "json_dataset"
|
| 87 |
+
huggingface:
|
| 88 |
+
repo_id: "GIZ/spaces_logs"
|
| 89 |
+
repo_type: "dataset"
|
| 90 |
+
folder_path: "json_dataset"
|
| 91 |
+
path_in_repo: "audit_chatbot"
|
| 92 |
+
token_env_var: "SPACES_LOG"
|
src/llm/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM adapters and utilities."""
|
| 2 |
+
|
| 3 |
+
from .adapters import LLMRegistry, get_llm_client
|
| 4 |
+
from .templates import get_message_template, PromptTemplate, create_audit_prompt
|
| 5 |
+
|
| 6 |
+
__all__ = ["LLMRegistry", "get_llm_client", "get_message_template", "PromptTemplate", "create_audit_prompt"]
|
src/llm/adapters.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM client adapters for different providers."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Any, List, Optional, Union
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
# LangChain imports
|
| 8 |
+
from langchain_mistralai.chat_models import ChatMistralAI
|
| 9 |
+
from langchain_openai.chat_models import ChatOpenAI
|
| 10 |
+
from langchain_ollama import ChatOllama
|
| 11 |
+
|
| 12 |
+
# Legacy client dependencies
|
| 13 |
+
from huggingface_hub import InferenceClient
|
| 14 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 15 |
+
from langchain_community.llms import HuggingFaceEndpoint
|
| 16 |
+
from langchain_community.chat_models.huggingface import ChatHuggingFace
|
| 17 |
+
|
| 18 |
+
# Configuration loader
|
| 19 |
+
from ..config.loader import load_config
|
| 20 |
+
|
| 21 |
+
# Load configuration once at module level
|
| 22 |
+
_config = load_config()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Legacy client factory functions (inlined from auditqa_old.reader)
|
| 26 |
+
def _create_inf_provider_client():
    """Create an InferenceClient for the INF_PROVIDERS backend.

    Reads the provider name and API key from the module-level ``_config``
    (section reader.INF_PROVIDERS).

    Returns:
        huggingface_hub.InferenceClient routed through the configured provider.

    Raises:
        ValueError: If ``api_key`` or ``provider`` is missing from configuration.
    """
    reader_config = _config.get("reader", {})
    inf_config = reader_config.get("INF_PROVIDERS", {})

    api_key = inf_config.get("api_key")
    if not api_key:
        raise ValueError("INF_PROVIDERS api_key not found in configuration")

    provider = inf_config.get("provider")
    if not provider:
        raise ValueError("INF_PROVIDERS provider not found in configuration")

    return InferenceClient(
        provider=provider,
        api_key=api_key,
        bill_to="GIZ",  # attribute inference usage/billing to the GIZ org
    )
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _create_nvidia_client():
    """Create an InferenceClient pointed at the configured NVIDIA endpoint.

    Reads ``api_key`` and ``endpoint`` from the module-level ``_config``
    (section reader.NVIDIA). Note: settings.yaml marks this provider as
    "Not working" — TODO confirm before relying on it.

    Returns:
        huggingface_hub.InferenceClient bound to the NVIDIA base URL.

    Raises:
        ValueError: If ``api_key`` or ``endpoint`` is missing from configuration.
    """
    reader_config = _config.get("reader", {})
    nvidia_config = reader_config.get("NVIDIA", {})

    api_key = nvidia_config.get("api_key")
    if not api_key:
        raise ValueError("NVIDIA api_key not found in configuration")

    endpoint = nvidia_config.get("endpoint")
    if not endpoint:
        raise ValueError("NVIDIA endpoint not found in configuration")

    return InferenceClient(
        base_url=endpoint,
        api_key=api_key
    )
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _create_serverless_client():
    """Create an InferenceClient for the HF serverless inference API.

    Reads ``api_key`` (required) and ``model`` (optional, defaults to
    Meta-Llama-3-8B-Instruct) from reader.SERVERLESS in ``_config``.

    Returns:
        huggingface_hub.InferenceClient bound to the configured model.

    Raises:
        ValueError: If ``api_key`` is missing from configuration.
    """
    reader_config = _config.get("reader", {})
    serverless_config = reader_config.get("SERVERLESS", {})

    api_key = serverless_config.get("api_key")
    if not api_key:
        raise ValueError("SERVERLESS api_key not found in configuration")

    model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")

    return InferenceClient(
        model=model_id,
        api_key=api_key,
    )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _create_dedicated_endpoint_client():
    """Create a streaming chat client for a dedicated HF inference endpoint.

    Reads ``api_key``, ``endpoint`` and optional ``max_tokens`` (default 768)
    from reader.DEDICATED in ``_config``. Streaming tokens are echoed to
    stdout via StreamingStdOutCallbackHandler.

    Returns:
        langchain ChatHuggingFace wrapping a streaming HuggingFaceEndpoint.

    Raises:
        ValueError: If ``api_key`` or ``endpoint`` is missing from configuration.
    """
    reader_config = _config.get("reader", {})
    dedicated_config = reader_config.get("DEDICATED", {})

    api_key = dedicated_config.get("api_key")
    if not api_key:
        raise ValueError("DEDICATED api_key not found in configuration")

    endpoint = dedicated_config.get("endpoint")
    if not endpoint:
        raise ValueError("DEDICATED endpoint not found in configuration")

    max_tokens = dedicated_config.get("max_tokens", 768)

    # Set up the streaming callback handler (prints tokens to stdout).
    callback = StreamingStdOutCallbackHandler()

    # Initialize the HuggingFaceEndpoint with streaming enabled.
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint,
        max_new_tokens=int(max_tokens),
        repetition_penalty=1.03,
        timeout=70,  # seconds; dedicated endpoints can be slow to cold-start
        huggingfacehub_api_token=api_key,
        streaming=True,
        callbacks=[callback]
    )

    # Create a ChatHuggingFace instance with the streaming-enabled endpoint.
    return ChatHuggingFace(llm=llm_qa)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
class LLMResponse:
    """Standardized LLM response format returned by all adapters."""
    content: str  # generated text
    model: str  # model identifier used for the generation
    provider: str  # provider name, e.g. "mistral", "openai", "ollama"
    # Provider-specific extras (e.g. {"usage": ...}); may be None.
    metadata: Optional[Dict[str, Any]] = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BaseLLMAdapter(ABC):
    """Base class for LLM adapters.

    Subclasses wrap one concrete provider client and expose the uniform
    generate / stream_generate interface used by the pipeline.
    """

    def __init__(self, config: Dict[str, Any]):
        # Provider-specific configuration section (model name, keys, etc.).
        self.config = config

    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate a complete response from a list of chat messages."""
        pass

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield response text incrementally from a list of chat messages."""
        pass
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class MistralAdapter(BaseLLMAdapter):
    """Adapter exposing Mistral AI chat models via the BaseLLMAdapter API."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        model_name = config.get("model", "mistral-medium-latest")
        self.model = ChatMistralAI(model=model_name)

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Run one blocking chat completion against Mistral."""
        result = self.model.invoke(messages)
        usage = getattr(result, 'usage_metadata', {})
        return LLMResponse(
            content=result.content,
            model=self.config.get("model", "mistral-medium-latest"),
            provider="mistral",
            metadata={"usage": usage},
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield non-empty text chunks as Mistral streams them."""
        for piece in self.model.stream(messages):
            if piece.content:
                yield piece.content
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class OpenAIAdapter(BaseLLMAdapter):
    """Adapter exposing OpenAI chat models via the BaseLLMAdapter API."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        model_name = config.get("model", "gpt-4o-mini")
        self.model = ChatOpenAI(model=model_name)

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Run one blocking chat completion against OpenAI."""
        result = self.model.invoke(messages)
        usage = getattr(result, 'usage_metadata', {})
        return LLMResponse(
            content=result.content,
            model=self.config.get("model", "gpt-4o-mini"),
            provider="openai",
            metadata={"usage": usage},
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield non-empty text chunks as OpenAI streams them."""
        for piece in self.model.stream(messages):
            if piece.content:
                yield piece.content
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class OllamaAdapter(BaseLLMAdapter):
    """Adapter exposing locally-hosted Ollama chat models."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        options = {
            "model": config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            "base_url": config.get("base_url", "http://localhost:11434/"),
            "temperature": config.get("temperature", 0.8),
            "num_predict": config.get("num_predict", 256),
        }
        self.model = ChatOllama(**options)

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Run one blocking chat completion against the Ollama server."""
        result = self.model.invoke(messages)
        return LLMResponse(
            content=result.content,
            model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
            provider="ollama",
            metadata={},  # Ollama responses carry no usage info here
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield non-empty text chunks as Ollama streams them."""
        for piece in self.model.stream(messages):
            if piece.content:
                yield piece.content
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class OpenRouterAdapter(BaseLLMAdapter):
    """Adapter exposing OpenRouter models through the OpenAI-compatible API."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # Optional attribution headers used by OpenRouter for site ranking.
        headers = {}
        site_url = config.get("site_url")
        if site_url:
            headers["HTTP-Referer"] = site_url
        site_name = config.get("site_name")
        if site_name:
            headers["X-Title"] = site_name

        # ChatOpenAI works against OpenRouter by overriding base_url.
        self.model = ChatOpenAI(
            model=config.get("model", "openai/gpt-3.5-turbo"),
            api_key=config.get("api_key"),
            base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
            default_headers=headers,
            temperature=config.get("temperature", 0.7),
            max_tokens=config.get("max_tokens", 1000),
        )

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Run one blocking chat completion against OpenRouter."""
        result = self.model.invoke(messages)
        usage = getattr(result, 'usage_metadata', {})
        return LLMResponse(
            content=result.content,
            model=self.config.get("model", "openai/gpt-3.5-turbo"),
            provider="openrouter",
            metadata={"usage": usage},
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Yield non-empty text chunks as OpenRouter streams them."""
        for piece in self.model.stream(messages):
            if piece.content:
                yield piece.content
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class LegacyAdapter(BaseLLMAdapter):
    """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, DEDICATED, SERVERLESS).

    Wraps the module-level client factory functions and normalizes their
    differing call shapes behind the BaseLLMAdapter interface.
    """

    def __init__(self, config: Dict[str, Any], client_type: str):
        super().__init__(config)
        # One of "INF_PROVIDERS", "NVIDIA", "DEDICATED"; anything else is
        # treated as SERVERLESS (see _create_client).
        self.client_type = client_type
        self.client = self._create_client()

    def _create_client(self):
        """Create legacy client based on type."""
        if self.client_type == "INF_PROVIDERS":
            return _create_inf_provider_client()
        elif self.client_type == "NVIDIA":
            return _create_nvidia_client()
        elif self.client_type == "DEDICATED":
            return _create_dedicated_endpoint_client()
        else:  # SERVERLESS
            return _create_serverless_client()

    def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
        """Generate response using the legacy client.

        ``max_tokens`` is taken from kwargs, then the adapter config,
        then defaults to 768. Each client type uses its own API shape.
        """
        max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))

        if self.client_type == "INF_PROVIDERS":
            # OpenAI-style chat.completions API on InferenceClient.
            response = self.client.chat.completions.create(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content

        elif self.client_type == "NVIDIA":
            response = self.client.chat_completion(
                model=self.config.get("model"),
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content

        else:  # DEDICATED or SERVERLESS
            # NOTE(review): the DEDICATED client is a ChatHuggingFace
            # instance (see _create_dedicated_endpoint_client); confirm it
            # actually exposes chat_completion like InferenceClient does.
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content

        return LLMResponse(
            content=content,
            model=self.config.get("model", "unknown"),
            provider=self.client_type.lower(),
            metadata={}
        )

    def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
        """Generate a pseudo-streaming response using the legacy client.

        Legacy clients may not support streaming in the same way, so this
        runs a full generate() and yields word-by-word. NOTE: splitting on
        whitespace and re-joining with single spaces does not preserve the
        original formatting (newlines, multiple spaces).
        """
        response = self.generate(messages, **kwargs)
        words = response.content.split()
        for word in words:
            yield word + " "
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class LLMRegistry:
    """Registry for managing different LLM adapters with lazy construction."""

    def __init__(self):
        self.adapters = {}         # name -> instantiated adapter (cache)
        self.adapter_configs = {}  # name -> (adapter factory, config)

    def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
        """Register an LLM adapter (lazy instantiation)."""
        self.adapter_configs[name] = (adapter_class, config)

    def get_adapter(self, name: str) -> BaseLLMAdapter:
        """Get an LLM adapter by name, instantiating it on first use."""
        try:
            factory, cfg = self.adapter_configs[name]
        except KeyError:
            raise ValueError(f"Unknown LLM adapter: {name}") from None

        # Lazy instantiation — construct only on first request, then cache.
        if name not in self.adapters:
            self.adapters[name] = factory(cfg)

        return self.adapters[name]

    def list_adapters(self) -> List[str]:
        """List available adapter names."""
        return list(self.adapter_configs.keys())
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
    """
    Create and populate LLM registry from configuration.

    Each adapter is registered only when its section exists under the
    "reader" key of *config*; construction itself is deferred until the
    adapter is first requested.

    Args:
        config: Configuration dictionary

    Returns:
        Populated LLMRegistry
    """
    registry = LLMRegistry()
    reader_config = config.get("reader", {})

    # LangChain-backed adapters, keyed by their config section name.
    simple_adapters = {
        "MISTRAL": ("mistral", MistralAdapter),
        "OPENAI": ("openai", OpenAIAdapter),
        "OLLAMA": ("ollama", OllamaAdapter),
        "OPENROUTER": ("openrouter", OpenRouterAdapter),
    }
    for section, (name, adapter_cls) in simple_adapters.items():
        if section in reader_config:
            registry.register_adapter(name, adapter_cls, reader_config[section])

    # Register legacy adapters.
    # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
    legacy_types = ["INF_PROVIDERS"]
    for legacy_type in legacy_types:
        if legacy_type in reader_config:
            # Bind legacy_type as a default arg to avoid late-binding issues.
            registry.register_adapter(
                legacy_type.lower(),
                lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
                reader_config[legacy_type],
            )

    return registry
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
    """
    Get the LLM adapter for the specified provider.

    Builds a fresh registry from *config* on every call and resolves the
    adapter from it.

    Args:
        provider: Provider name (mistral, openai, ollama, etc.)
        config: Configuration dictionary.

    Returns:
        LLM adapter instance.
    """
    return create_llm_registry(config).get_adapter(provider)
|
src/llm/templates.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM prompt templates and message formatting utilities."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict, Any, Union
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from langchain.schema import SystemMessage, HumanMessage
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class PromptTemplate:
    """A (system prompt, user prompt template) pair with variable substitution."""

    system_prompt: str           # static system prompt, used verbatim
    user_prompt_template: str    # str.format template for the user turn

    def format(self, **kwargs) -> tuple:
        """Return ``(system_prompt, formatted_user_prompt)`` with *kwargs* substituted."""
        return self.system_prompt, self.user_prompt_template.format(**kwargs)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Default system prompt for audit Q&A.
# Includes few-shot examples so the model answers factually, concisely,
# and cites retrieved passages using the [Doc i] convention.
DEFAULT_AUDIT_SYSTEM_PROMPT = """
You are AuditQ&A, an AI Assistant for audit reports. Answer questions directly and factually based on the provided context.

Guidelines:
- Answer directly and concisely (2-3 sentences maximum)
- Use specific facts and numbers from the context
- Cite sources using [Doc i] format
- Be factual, not opinionated
- Avoid phrases like "From my point of view", "I think", "It seems"

Examples:

Query: "What challenges arise from contradictory PDM implementation guidelines?"
Context: [Retrieved documents about PDM guidelines contradictions]
Answer: "Contradictory PDM implementation guidelines cause challenges during implementation, as entities receive numerous and often conflicting directives from different authorities. For example, guidelines on transfer of funds to PDM SACCOs differ between the PDM Secretariat and PSST, and there are conflicting directives on fund diversion from various authorities."

Query: "What was the supplementary funding obtained for the wage budget?"
Context: [Retrieved documents about wage budget funding]
Answer: "The supplementary funding obtained for the wage budget was UGX.2,208,040,656."

Now answer the following question based on the provided context:
"""

# Default user prompt template.
# Expects two placeholders: {context} (numbered passages) and {question}.
DEFAULT_USER_PROMPT_TEMPLATE = """Passages:
{context}
-----------------------
Question: {question} - Explained to audit expert
Answer in english with the passages citations:
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def create_audit_prompt(context_list: List[str], query: str) -> List[Dict[str, str]]:
    """
    Create audit Q&A chat messages from context passages and a query.

    Args:
        context_list: List of context passages.
        query: User query.

    Returns:
        List of ``{"role", "content"}`` message dictionaries for the LLM.
    """
    # Number each passage so the model can cite it as [Doc i].
    context_str = "\n\n".join(
        f"Doc {idx}: {passage}" for idx, passage in enumerate(context_list, 1)
    )

    user_prompt = DEFAULT_USER_PROMPT_TEMPLATE.format(context=context_str, question=query)

    return [
        {"role": "system", "content": DEFAULT_AUDIT_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_message_template(
    provider_type: str,
    system_prompt: str,
    user_prompt: str
) -> List[Union[Dict[str, str], SystemMessage, HumanMessage]]:
    """
    Format a system/user prompt pair for the given LLM provider.

    Args:
        provider_type: Type of LLM provider (case-insensitive).
        system_prompt: System prompt content.
        user_prompt: User prompt content.

    Returns:
        LangChain message objects for local/dedicated providers
        (DEDICATED, SERVERLESS, OLLAMA); OpenAI-style dicts for API-based
        providers and any unrecognized provider type.
    """
    if provider_type.upper() in ('DEDICATED', 'SERVERLESS', 'OLLAMA'):
        # LangChain message objects for local/dedicated providers.
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

    # API-based providers (NVIDIA, INF_PROVIDERS, MISTRAL, OPENAI, OPENROUTER)
    # and unknown provider types all take plain dict messages.
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def create_custom_prompt_template(
    system_prompt: str,
    user_template: str
) -> PromptTemplate:
    """
    Build a PromptTemplate from a system prompt and a user template string.

    Args:
        system_prompt: System prompt content.
        user_template: User prompt template with placeholders.

    Returns:
        PromptTemplate instance.
    """
    return PromptTemplate(
        system_prompt=system_prompt,
        user_prompt_template=user_template,
    )
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def create_evaluation_prompt(context_list: List[str], query: str, expected_answer: str) -> List[Dict[str, str]]:
    """
    Build chat messages asking an LLM to judge how well the context
    supports answering a question, given a ground-truth answer.

    Args:
        context_list: List of context passages.
        query: User query.
        expected_answer: Expected/ground truth answer.

    Returns:
        List of message dictionaries for evaluation.
    """
    # Number passages for citation, mirroring the answering prompt.
    context_str = "\n\n".join(
        f"Doc {i}: {passage}" for i, passage in enumerate(context_list, 1)
    )

    evaluation_system_prompt = """
You are an evaluation assistant. Given context passages, a question, and an expected answer,
evaluate how well the provided context supports answering the question accurately.

Provide your evaluation focusing on:
1. Relevance of the context to the question
2. Completeness of information needed to answer
3. Quality and accuracy of supporting details
"""

    user_prompt = f"""Context Passages:
{context_str}

Question: {query}
Expected Answer: {expected_answer}

Evaluation:"""

    return [
        {"role": "system", "content": evaluation_system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_prompt_variants() -> Dict[str, PromptTemplate]:
    """
    Build the named prompt template variants used for testing.

    Returns:
        Dictionary mapping variant name ("standard", "concise", "detailed")
        to its PromptTemplate.
    """
    # Short prompt pair for the "concise" variant.
    concise_system = """You are an audit report AI assistant. Provide clear, concise answers based on the given context passages. Always cite sources using [Doc i] format."""
    concise_user = """Context:\n{context}\n\nQuestion: {question}\nAnswer:"""

    # "detailed" extends the default system prompt with extra instructions.
    detailed_system = DEFAULT_AUDIT_SYSTEM_PROMPT + """\n\nAdditional Instructions:
- Provide detailed explanations with specific examples
- Include relevant numbers, dates, and financial figures when available
- Structure your response with clear headings when appropriate
- Explain the significance of findings in the context of governance and accountability"""

    return {
        "standard": create_custom_prompt_template(
            DEFAULT_AUDIT_SYSTEM_PROMPT,
            DEFAULT_USER_PROMPT_TEMPLATE,
        ),
        "concise": create_custom_prompt_template(concise_system, concise_user),
        "detailed": create_custom_prompt_template(detailed_system, DEFAULT_USER_PROMPT_TEMPLATE),
    }
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Backward compatibility function
|
| 218 |
+
def format_context_with_citations(context_list: List[str]) -> str:
    """
    Format context passages with "Doc i" citation prefixes.

    Kept for backward compatibility with callers of the legacy module.

    Args:
        context_list: List of context passages.

    Returns:
        Passages joined by blank lines, each prefixed with its doc number.
    """
    return "\n\n".join(
        f"Doc {i}: {passage}" for i, passage in enumerate(context_list, 1)
    )
|
src/loader.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data loading utilities for chunks and JSON files."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
from langchain.docstore.document import Document
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_json(filepath: Path | str) -> List[Dict[str, Any]]:
|
| 10 |
+
"""
|
| 11 |
+
Load JSON data from file.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
filepath: Path to JSON file
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
List of dictionaries containing the JSON data
|
| 18 |
+
"""
|
| 19 |
+
filepath = Path(filepath)
|
| 20 |
+
|
| 21 |
+
if not filepath.exists():
|
| 22 |
+
raise FileNotFoundError(f"JSON file not found: {filepath}")
|
| 23 |
+
|
| 24 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 25 |
+
data = json.load(f)
|
| 26 |
+
|
| 27 |
+
return data
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def open_file(filepath: Path | str) -> str:
|
| 31 |
+
"""
|
| 32 |
+
Open and read a text file.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
filepath: Path to text file
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
File contents as string
|
| 39 |
+
"""
|
| 40 |
+
filepath = Path(filepath)
|
| 41 |
+
|
| 42 |
+
if not filepath.exists():
|
| 43 |
+
raise FileNotFoundError(f"File not found: {filepath}")
|
| 44 |
+
|
| 45 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 46 |
+
content = f.read()
|
| 47 |
+
|
| 48 |
+
return content
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_chunks(chunks_file: Path | str | None = None) -> List[Dict[str, Any]]:
    """
    Load document chunks from JSON file.

    Args:
        chunks_file: Path to chunks JSON file. If None, uses default path.

    Returns:
        List of chunk dictionaries.

    Raises:
        FileNotFoundError: If the chunks file does not exist (via load_json).
    """
    if chunks_file is None:
        # Default location of the docling chunking output.
        chunks_file = Path("reports/docling_chunks.json")

    return load_json(chunks_file)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def chunks_to_documents(chunks: List[Dict[str, Any]]) -> List[Document]:
    """
    Convert chunk dictionaries to LangChain Document objects.

    Missing "content" defaults to an empty string and missing "metadata"
    to an empty dict, so malformed chunks do not raise.

    Args:
        chunks: List of chunk dictionaries.

    Returns:
        List of Document objects.
    """
    return [
        Document(
            page_content=chunk.get("content", ""),
            metadata=chunk.get("metadata", {}),
        )
        for chunk in chunks
    ]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def validate_chunks(chunks: List[Dict[str, Any]]) -> bool:
    """
    Validate that chunks have the required fields.

    Args:
        chunks: List of chunk dictionaries.

    Returns:
        True if all chunks are valid.

    Raises:
        ValueError: If a chunk is missing a required field, its metadata is
            not a dict, or the metadata lacks a 'filename' entry.
    """
    for index, chunk in enumerate(chunks):
        # Top-level fields every chunk must carry.
        for field in ("content", "metadata"):
            if field not in chunk:
                raise ValueError(f"Chunk {index} missing required field: {field}")

        metadata = chunk["metadata"]
        if not isinstance(metadata, dict):
            raise ValueError(f"Chunk {index} metadata must be a dictionary")

        # Check for common metadata fields.
        if "filename" not in metadata:
            raise ValueError(f"Chunk {index} metadata missing 'filename' field")

    return True
|
src/logging.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging utilities (placeholder for legacy compatibility)."""
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from uuid import uuid4
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from threading import Lock
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
|
| 10 |
+
from .config import load_config
|
| 11 |
+
|
| 12 |
+
# Shared fallback lock. The previous implementation created a fresh Lock()
# per call when the scheduler had none, which never actually serialized
# concurrent writers; a single module-level lock does.
_fallback_lock = Lock()


def save_logs(
    scheduler=None,
    json_dataset_path: Path = None,
    logs_data: Dict[str, Any] = None,
    feedback: str = None
) -> None:
    """
    Append one log record to a JSONL dataset file.

    Args:
        scheduler: HuggingFace scheduler; if it exposes a ``lock`` attribute,
            that lock serializes the write, otherwise a shared module lock is used.
        json_dataset_path: Path to the JSONL file to append to.
        logs_data: Log data dictionary; mutated in place with ``time``,
            ``record_id`` and (optionally) ``feedback``.
        feedback: Optional user feedback to store with the record.

    Raises:
        Exception: Re-raises any error encountered while writing the record.

    Note:
        This is kept for backward compatibility with the legacy logging flow.
        It is a no-op when logging is disabled.
    """
    if not is_logging_enabled():
        return
    try:
        current_time = datetime.now().timestamp()
        logs_data["time"] = str(current_time)
        if feedback:
            logs_data["feedback"] = feedback
        logs_data["record_id"] = str(uuid4())
        # Canonical field order for downstream consumers of the JSONL file;
        # fields absent from logs_data are simply skipped.
        field_order = [
            "record_id",
            "session_id",
            "time",
            "session_duration_seconds",
            "client_location",
            "platform",
            "system_prompt",
            "sources",
            "reports",
            "subtype",
            "year",
            "question",
            "retriever",
            "endpoint_type",
            "reader",
            "docs",
            "answer",
            "feedback"
        ]
        ordered_logs = {k: logs_data.get(k) for k in field_order if k in logs_data}
        # Prefer the scheduler's lock so we cooperate with its own writes.
        lock = getattr(scheduler, "lock", None) or _fallback_lock
        with lock:
            with open(json_dataset_path, 'a') as f:
                json.dump(ordered_logs, f)
                f.write("\n")
        logging.info("logging done")
    except Exception as e:
        logging.error(f"Error saving logs: {e}")
        raise
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def setup_logging(log_level: str = "INFO", log_file: str = None) -> None:
    """
    Configure the root logger.

    No-op when logging is disabled via environment/config.

    Args:
        log_level: Logging level name (e.g. "INFO", "DEBUG").
        log_file: Optional log file path; when omitted, a NullHandler is
            used in its place so only the stream handler emits output.
    """
    if not is_logging_enabled():
        return

    secondary_handler = logging.FileHandler(log_file) if log_file else logging.NullHandler()
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler(), secondary_handler],
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def log_query_response(
    query: str,
    response: str,
    metadata: Dict[str, Any] = None
) -> None:
    """
    Log a query and (summarized) response for later analysis.

    Only the response length is recorded, not its content.

    Args:
        query: User query.
        response: System response.
        metadata: Additional metadata.
    """
    if not is_logging_enabled():
        return

    entry = {
        "query": query,
        "response_length": len(response),
        "metadata": metadata or {},
    }
    logging.getLogger(__name__).info(f"Query processed: {entry}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def log_error(error: Exception, context: Dict[str, Any] = None) -> None:
    """
    Log an exception together with contextual information.

    Args:
        error: Exception that occurred.
        context: Additional context information.
    """
    if not is_logging_enabled():
        return

    details = {
        "error_type": type(error).__name__,
        "error_message": str(error),
        "context": context or {},
    }
    logging.getLogger(__name__).error(f"Error occurred: {details}")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def log_performance_metrics(
    operation: str,
    duration: float,
    metadata: Dict[str, Any] = None
) -> None:
    """
    Log timing metrics for an operation.

    Args:
        operation: Name of the operation.
        duration: Duration in seconds.
        metadata: Additional metadata.
    """
    if not is_logging_enabled():
        return

    payload = {
        "operation": operation,
        "duration_seconds": duration,
        "metadata": metadata or {},
    }
    logging.getLogger(__name__).info(f"Performance metrics: {payload}")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def is_session_enabled() -> bool:
    """
    Return True if session management is enabled, False otherwise.

    The ENABLE_SESSION environment variable takes precedence; otherwise
    the ``features.enable_session`` config flag is used (default True).
    """
    # Local import: the module's top-level imports do not include os,
    # which previously caused a NameError on this call.
    import os

    env = os.getenv("ENABLE_SESSION")
    if env is not None:
        return env.lower() in ("1", "true", "yes", "on")
    config = load_config()
    return config.get("features", {}).get("enable_session", True)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def is_logging_enabled() -> bool:
    """
    Return True if logging is enabled, False otherwise.

    The ENABLE_LOGGING environment variable takes precedence; otherwise
    the ``features.enable_logging`` config flag is used (default True).
    """
    # Local import: the module's top-level imports do not include os,
    # which previously caused a NameError on this call.
    import os

    env = os.getenv("ENABLE_LOGGING")
    if env is not None:
        return env.lower() in ("1", "true", "yes", "on")
    config = load_config()
    return config.get("features", {}).get("enable_logging", True)
|
| 193 |
+
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main pipeline orchestrator for the Audit QA system."""
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, Any, List, Optional
|
| 6 |
+
|
| 7 |
+
from langchain.docstore.document import Document
|
| 8 |
+
|
| 9 |
+
from .logging import log_error
|
| 10 |
+
from .llm.adapters import LLMRegistry
|
| 11 |
+
from .loader import chunks_to_documents
|
| 12 |
+
from .vectorstore import VectorStoreManager
|
| 13 |
+
from .retrieval.context import ContextRetriever
|
| 14 |
+
from .config.loader import get_embedding_model_for_collection
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class PipelineResult:
    """Result of a single pipeline execution."""
    answer: str                    # generated answer text
    sources: List[Document]       # documents used as evidence
    execution_time: float         # wall-clock seconds for the run
    metadata: Dict[str, Any]      # extra run information
    query: str = ""               # original query; defaulted for legacy callers

    def __post_init__(self):
        """Substitute a placeholder when no query was recorded."""
        if not self.query:
            self.query = "Unknown query"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class PipelineManager:
|
| 34 |
+
"""Main pipeline manager for the RAG system."""
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: dict = None):
|
| 37 |
+
"""
|
| 38 |
+
Initialize the pipeline manager.
|
| 39 |
+
"""
|
| 40 |
+
self.config = config or {}
|
| 41 |
+
self.vectorstore_manager = None
|
| 42 |
+
self.context_retriever = None # Initialize as None
|
| 43 |
+
self.llm_client = None
|
| 44 |
+
self.report_service = None
|
| 45 |
+
self.chunks = None
|
| 46 |
+
|
| 47 |
+
# Initialize components
|
| 48 |
+
self._initialize_components()
|
| 49 |
+
|
| 50 |
+
def update_config(self, new_config: dict):
|
| 51 |
+
"""
|
| 52 |
+
Update the pipeline configuration.
|
| 53 |
+
This is useful for experiments that need different settings.
|
| 54 |
+
"""
|
| 55 |
+
if not isinstance(new_config, dict):
|
| 56 |
+
return
|
| 57 |
+
|
| 58 |
+
# Deep merge the new config with existing config
|
| 59 |
+
def deep_merge(base_dict, update_dict):
|
| 60 |
+
for key, value in update_dict.items():
|
| 61 |
+
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
|
| 62 |
+
deep_merge(base_dict[key], value)
|
| 63 |
+
else:
|
| 64 |
+
base_dict[key] = value
|
| 65 |
+
|
| 66 |
+
deep_merge(self.config, new_config)
|
| 67 |
+
|
| 68 |
+
# Auto-infer embedding model from collection name if not "docling"
|
| 69 |
+
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
|
| 70 |
+
if collection_name != 'docling':
|
| 71 |
+
inferred_model = get_embedding_model_for_collection(collection_name)
|
| 72 |
+
if inferred_model:
|
| 73 |
+
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
|
| 74 |
+
if 'retriever' not in self.config:
|
| 75 |
+
self.config['retriever'] = {}
|
| 76 |
+
self.config['retriever']['model'] = inferred_model
|
| 77 |
+
# Set default normalize parameter if not present
|
| 78 |
+
if 'normalize' not in self.config['retriever']:
|
| 79 |
+
self.config['retriever']['normalize'] = True
|
| 80 |
+
|
| 81 |
+
# Also update vectorstore config if it exists
|
| 82 |
+
if 'vectorstore' in self.config:
|
| 83 |
+
self.config['vectorstore']['embedding_model'] = inferred_model
|
| 84 |
+
|
| 85 |
+
print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
|
| 86 |
+
|
| 87 |
+
# Re-initialize vectorstore manager with updated config
|
| 88 |
+
self._reinitialize_vectorstore_manager()
|
| 89 |
+
|
| 90 |
+
def _reinitialize_vectorstore_manager(self):
|
| 91 |
+
"""Re-initialize vectorstore manager with current config."""
|
| 92 |
+
try:
|
| 93 |
+
self.vectorstore_manager = VectorStoreManager(self.config)
|
| 94 |
+
print("🔄 VectorStore manager re-initialized with updated config")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"❌ Error re-initializing vectorstore manager: {e}")
|
| 97 |
+
|
| 98 |
+
def _get_reranker_model_name(self) -> str:
|
| 99 |
+
"""
|
| 100 |
+
Get the reranker model name from configuration.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Reranker model name or default
|
| 104 |
+
"""
|
| 105 |
+
return (
|
| 106 |
+
self.config.get('retrieval', {}).get('reranker_model') or
|
| 107 |
+
self.config.get('ranker', {}).get('model') or
|
| 108 |
+
self.config.get('reranker_model') or
|
| 109 |
+
'BAAI/bge-reranker-v2-m3'
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def _initialize_components(self):
    """Initialize pipeline components: config, vectorstore, LLM client,
    system prompt, and the report service.

    Best-effort: any failure is printed and leaves the affected component
    unset (``None``) rather than raising, so the pipeline object is always
    constructible.
    """
    try:
        # Load config if not provided
        if not self.config:
            from auditqa.config.loader import load_config
            self.config = load_config()

        # Auto-infer embedding model from collection name if not "docling"
        self._apply_collection_embedding_model()

        self.vectorstore_manager = VectorStoreManager(self.config)

        self.llm_manager = LLMRegistry()
        # Three-stage fallback chain; returns None when no client can be built.
        self.llm_client = self._build_llm_client()

        # Load system prompt
        from auditqa.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
        self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT

        # Initialize report service (optional component)
        try:
            from auditqa.reporting.service import ReportService
            self.report_service = ReportService()
        except Exception as e:
            print(f"Warning: Could not initialize report service: {e}")
            self.report_service = None

    except Exception as e:
        print(f"Warning: Error initializing components: {e}")

def _apply_collection_embedding_model(self):
    """Auto-infer the embedding model for a non-default Qdrant collection
    and push it into the retriever/vectorstore config sections."""
    collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
    if collection_name == 'docling':
        return
    inferred_model = get_embedding_model_for_collection(collection_name)
    if not inferred_model:
        return
    print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
    retriever_cfg = self.config.setdefault('retriever', {})
    retriever_cfg['model'] = inferred_model
    # Set default normalize parameter if not present
    retriever_cfg.setdefault('normalize', True)
    # Also update vectorstore config if it exists
    if 'vectorstore' in self.config:
        self.config['vectorstore']['embedding_model'] = inferred_model

def _build_llm_client(self):
    """Build an LLM client, trying the registry adapter, then the factory
    function, then a direct ChatOpenAI instance.

    Returns:
        A client object, or None when every strategy fails.
    """
    # 1) Registry adapter (preferred path).
    try:
        client = self.llm_manager.get_adapter("openai")
        print("✅ LLM CLIENT: Initialized using get_adapter method")
        return client
    except Exception:
        pass  # fall through to the factory function

    # 2) Factory function with explicit config.
    try:
        from auditqa.llm.adapters import get_llm_client
        client = get_llm_client("openai", self.config)
        print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
        return client
    except Exception as e2:
        print(f"❌ LLM CLIENT: Registry methods failed - {e2}")

    # 3) Last resort: instantiate ChatOpenAI directly from env credentials.
    try:
        from langchain_openai import ChatOpenAI
        import os
        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            print("❌ LLM CLIENT: No API key available")
            return None
        client = ChatOpenAI(
            model="gpt-3.5-turbo",
            api_key=api_key,
            temperature=0.1,
            max_tokens=1000,
        )
        print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
        return client
    except Exception as e3:
        print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
        return None
def test_retrieval(
    self,
    query: str,
    reports: List[str] = None,
    sources: str = None,
    subtype: List[str] = None,
    k: int = None,
    search_mode: str = None,
    search_alpha: float = None,
    use_reranking: bool = True
) -> Dict[str, Any]:
    """
    Test retrieval only without LLM inference.

    Args:
        query: User query
        reports: List of specific report filenames
        sources: Source category
        subtype: List of subtypes
        k: Number of documents to retrieve
        search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
        search_alpha: Weight for vector scores in hybrid mode
        use_reranking: Whether to use reranking (currently accepted for API
            symmetry; retrieve_with_scores does not take it)

    Returns:
        Dictionary with retrieval results and metadata; on failure the dict
        carries an 'error' key and an empty result list.
    """
    start_time = time.time()

    try:
        # Fall back to configured defaults only when the caller supplied
        # nothing. Use `is None` (not truthiness) so an explicit
        # search_alpha=0.0 or an empty-string mode is respected.
        if search_mode is None:
            search_mode = self.config.get("hybrid", {}).get("default_mode", "vector_only")
        if search_alpha is None:
            search_alpha = self.config.get("hybrid", {}).get("default_alpha", 0.5)

        # Get vector store
        vectorstore = self.vectorstore_manager.get_vectorstore()
        if not vectorstore:
            raise ValueError(
                "Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
            )

        # Retrieve context with scores for test retrieval
        context_docs_with_scores = self.context_retriever.retrieve_with_scores(
            vectorstore=vectorstore,
            query=query,
            reports=reports,
            sources=sources,
            subtype=subtype,
            k=k,
            search_mode=search_mode,
            alpha=search_alpha,
        )

        execution_time = time.time() - start_time

        # Format results with actual scores; a missing score maps to 0.0.
        results = [
            {
                "rank": rank,
                "content": doc.page_content,  # full content, no truncation
                "metadata": doc.metadata,
                "score": score if score is not None else 0.0,
            }
            for rank, (doc, score) in enumerate(context_docs_with_scores, start=1)
        ]

        return {
            "results": results,
            "num_results": len(results),
            "execution_time": execution_time,
            "search_mode": search_mode,
            "search_alpha": search_alpha,
            "query": query
        }

    except Exception as e:
        print(f"❌ Error during retrieval test: {e}")
        log_error(e, {"component": "retrieval_test", "query": query})
        return {
            "results": [],
            "num_results": 0,
            "execution_time": time.time() - start_time,
            "error": str(e),
            # FIX: `or`-fallbacks here used to coerce a legitimate 0.0 alpha
            # to 0.5 (and "" mode to "unknown"); only substitute when unset.
            "search_mode": search_mode if search_mode is not None else "unknown",
            "search_alpha": search_alpha if search_alpha is not None else 0.5,
            "query": query
        }
def connect_vectorstore(self, force_recreate: bool = False) -> bool:
    """
    Connect to existing vector store.

    Args:
        force_recreate: If True, recreate the collection if dimension mismatch occurs

    Returns:
        True if successful, False otherwise
    """
    try:
        store = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
    except Exception as e:
        print(f"❌ Error connecting to vector store: {e}")
        log_error(e, {"component": "vectorstore_connection"})

        # On a dimension mismatch, retry exactly once with force_recreate
        # (unless the caller already asked for recreation).
        if "dimensions" in str(e).lower() and not force_recreate:
            print("🔄 Dimension mismatch detected, attempting to recreate collection...")
            try:
                store = self.vectorstore_manager.connect_to_existing(force_recreate=True)
            except Exception as recreate_error:
                print(f"❌ Failed to recreate vector store: {recreate_error}")
                log_error(recreate_error, {"component": "vectorstore_recreation"})
            else:
                if store:
                    print("✅ Connected to vector store (recreated)")
                    return True
        return False

    if store:
        print("✅ Connected to vector store")
        return True
    print("❌ Failed to connect to vector store")
    return False
def create_vectorstore(self) -> bool:
    """
    Create new vector store from chunks.

    Returns:
        True if successful, False otherwise
    """
    try:
        if not self.chunks:
            raise ValueError("No chunks available for vector store creation")
        docs = chunks_to_documents(self.chunks)
        self.vectorstore_manager.create_from_documents(docs)
    except Exception as e:
        print(f"❌ Error creating vector store: {e}")
        log_error(e, {"component": "vectorstore_creation"})
        return False
    print("✅ Vector store created successfully")
    return True
def create_audit_prompt(self, query: str, context_docs: "List[Document]") -> str:
    """Assemble the LLM prompt from the system prompt, the retrieved
    context documents, and the user's query.

    Accepts both Document-like objects (with a non-empty string
    ``page_content``) and plain non-blank strings; anything else is
    silently skipped. Returns an error string instead of raising.
    """
    try:
        if not query or not isinstance(query, str) or query.strip() == "":
            return "Error: No query provided"

        docs = context_docs if context_docs is not None else []

        # Keep only entries that carry usable text.
        usable = []
        for doc in docs:
            if doc is None:
                continue
            if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
                usable.append(doc)
            elif isinstance(doc, str) and doc.strip():
                usable.append(doc)

        if usable:
            numbered = []
            for idx, doc in enumerate(usable, 1):
                text = doc.page_content if hasattr(doc, 'page_content') else doc
                numbered.append(f"Doc {idx}: {text}")
            context_string = "\n\n".join(numbered)
        else:
            context_string = "No relevant context found."

        return f"""
{self.system_prompt}

Context:
{context_string}

Query: {query}

Answer:"""

    except Exception as e:
        print(f"Error creating audit prompt: {e}")
        return f"Error creating prompt: {e}"
def _generate_answer(self, prompt: str) -> str:
    """Run the prompt through the configured LLM client and return the
    answer as stripped plain text; error conditions come back as
    'Error: ...' strings rather than exceptions."""
    try:
        if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
            return "Error: No prompt provided"
        if not self.llm_client:
            return "Error: LLM client not available"

        client = self.llm_client
        if hasattr(client, 'generate'):
            # Adapter-style client: chat-message interface, LLMResponse out.
            response = client.generate([{"role": "user", "content": prompt}])
            answer = response.content if hasattr(response, 'content') else str(response)
        elif hasattr(client, 'invoke'):
            # Direct LangChain model.
            response = client.invoke(prompt)
            if hasattr(response, 'content') and response.content is not None:
                answer = response.content
            elif isinstance(response, str) and response.strip():
                answer = response
            else:
                answer = str(response) if response is not None else "Error: LLM returned None response"
        else:
            return "Error: LLM client has no generate or invoke method"

        if answer is None or not isinstance(answer, str):
            return "Error: LLM returned invalid response"
        return answer.strip()

    except Exception as e:
        print(f"Error generating answer: {e}")
        return f"Error generating answer: {e}"
def run(
    self,
    query: str,
    reports: List[str] = None,
    sources: List[str] = None,
    subtype: List[str] = None,
    llm_provider: str = None,
    use_reranking: bool = True,
    search_mode: str = None,
    search_alpha: float = None,
    auto_infer_filters: bool = True,
    filters: Dict[str, Any] = None,
) -> PipelineResult:
    """
    Run the complete RAG pipeline.

    Args:
        query: User query
        reports: List of specific report filenames
        sources: Source category filter
        subtype: List of subtypes/filenames
        llm_provider: LLM provider to use (recorded in metadata only)
        use_reranking: Whether to use reranking (overridden by config below)
        search_mode: Search mode (vector, sparse, hybrid)
        search_alpha: Alpha value for hybrid search
        auto_infer_filters: Whether to auto-infer filters from query
        filters: Dict of filter lists (reports/sources/subtype/year/
            district/filenames); when provided, supersedes the explicit
            reports/sources/subtype arguments

    Returns:
        PipelineResult object; every failure path returns a PipelineResult
        whose metadata carries an 'error' key (never None).
    """
    try:
        # Validate input
        if not query or not isinstance(query, str) or query.strip() == "":
            return PipelineResult(
                answer="Error: Invalid query provided",
                sources=[],
                execution_time=0.0,
                metadata={'error': 'Invalid query'},
                query=query
            )

        # Ensure lists are not None
        if reports is None:
            reports = []
        if subtype is None:
            subtype = []

        start_time = time.time()

        # Auto-infer filters if enabled and no explicit filters provided
        inferred_filters = {}
        filters_applied = False
        qdrant_filter = None
        # FIX: initialized up front instead of the previous fragile
        # `'filter_summary' in locals()` probe at result-building time.
        filter_summary = None

        if auto_infer_filters and not any([reports, sources, subtype]):
            print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
            try:
                # Imported here to avoid circular imports
                from auditqa.retrieval.filter import get_available_metadata, infer_filters_from_query

                available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())

                # Returns a ready-to-use Qdrant filter plus a summary
                qdrant_filter, filter_summary = infer_filters_from_query(
                    query=query,
                    available_metadata=available_metadata,
                    llm_client=self.llm_client
                )

                if qdrant_filter:
                    print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
                    filters_applied = True
                else:
                    print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")

            except Exception as e:
                print(f"❌ AUTO-INFERENCE FAILED: {e}")
                qdrant_filter = None
        else:
            filters_applied = any([reports, sources, subtype])
            if filters_applied:
                print(f"✅ EXPLICIT FILTERS: Using provided filters")
            else:
                print(f"⚠️ NO FILTERS: No explicit filters and auto-inference disabled")

        # NOTE: from here on the explicit reports/sources/subtype arguments
        # are superseded by the `filters` dict (pre-existing behavior,
        # deliberately preserved).
        reports = filters.get('reports', []) if filters else []
        sources = filters.get('sources', []) if filters else []
        subtype = filters.get('subtype', []) if filters else []
        year = filters.get('year', []) if filters else []
        district = filters.get('district', []) if filters else []
        filenames = filters.get('filenames', []) if filters else []  # mutually exclusive filename filtering

        # Get vectorstore
        vectorstore = self.vectorstore_manager.get_vectorstore()
        if not vectorstore:
            return PipelineResult(
                answer="Error: Vector store not available",
                sources=[],
                execution_time=0.0,
                metadata={'error': 'Vector store not available'},
                query=query
            )

        # Initialize context retriever lazily
        if not hasattr(self, 'context_retriever') or self.context_retriever is None:
            vectorstore_obj = self.vectorstore_manager.get_vectorstore()
            if vectorstore_obj is None:
                print("❌ ERROR: Vectorstore is None, cannot initialize ContextRetriever")
                # FIX: previously returned bare None, breaking the declared
                # `-> PipelineResult` contract for callers.
                return PipelineResult(
                    answer="Error: Vector store not available",
                    sources=[],
                    execution_time=time.time() - start_time,
                    metadata={'error': 'Vector store not available'},
                    query=query
                )
            self.context_retriever = ContextRetriever(vectorstore_obj, self.config)
            print("✅ ContextRetriever initialized successfully")

        # Debug config access
        print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
        print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
        print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
        print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")

        # top_k priority: experiment config > retriever config > default
        top_k = (
            self.config.get('retrieval', {}).get('top_k') or
            self.config.get('retriever', {}).get('top_k') or
            5
        )

        # NOTE: config deliberately overrides the use_reranking argument
        # (pre-existing behavior, preserved).
        use_reranking = self.config.get('retrieval', {}).get('use_reranking', False)

        print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
        print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")

        # Retrieve context
        context_docs = self.context_retriever.retrieve_context(
            query=query,
            k=top_k,
            reports=reports,
            sources=sources,
            subtype=subtype,
            year=year,
            district=district,
            filenames=filenames,
            use_reranking=use_reranking,
            qdrant_filter=qdrant_filter
        )
        if context_docs is None:
            context_docs = []

        # Generate answer
        answer = self._generate_answer(self.create_audit_prompt(query, context_docs))

        execution_time = time.time() - start_time

        # Create result with comprehensive metadata
        result = PipelineResult(
            answer=answer,
            sources=context_docs,
            execution_time=execution_time,
            metadata={
                'llm_provider': llm_provider,
                'use_reranking': use_reranking,
                'search_mode': search_mode,
                'search_alpha': search_alpha,
                'auto_infer_filters': auto_infer_filters,
                'filters_applied': filters_applied,
                'with_filtering': filters_applied,
                'filter_conditions': {
                    'reports': reports,
                    'sources': sources,
                    'subtype': subtype
                },
                'inferred_filters': inferred_filters,
                'applied_filters': {
                    'reports': reports,
                    'sources': sources,
                    'subtype': subtype
                },
                'filter_details': {
                    'explicit_filters': {
                        'reports': reports,
                        'sources': sources,
                        'subtype': subtype,
                        'year': year
                    },
                    'inferred_filters': inferred_filters if auto_infer_filters else {},
                    'auto_inference_enabled': auto_infer_filters,
                    'qdrant_filter_applied': qdrant_filter is not None,
                    'filter_summary': filter_summary
                },
                'reranker_model': self._get_reranker_model_name() if use_reranking else None,
                'reranker_applied': use_reranking,
                'reranking_info': {
                    'model': self._get_reranker_model_name(),
                    'applied': use_reranking,
                    'top_k': len(context_docs) if context_docs else 0,
                    'reranked_documents': [
                        {
                            'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                            'metadata': doc.metadata,
                            'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
                            'original_rank': doc.metadata.get('original_rank', None),
                            'final_rank': doc.metadata.get('final_rank', None),
                            'reranked_score': doc.metadata.get('reranked_score', None)
                        } for doc in context_docs
                    ] if use_reranking else None
                }
            },
            query=query
        )

        return result

    except Exception as e:
        print(f"Error in pipeline run: {e}")
        return PipelineResult(
            answer=f"Error processing query: {e}",
            sources=[],
            execution_time=0.0,
            metadata={'error': str(e)},
            query=query
        )
def get_system_status(self) -> Dict[str, Any]:
    """
    Summarize component readiness for diagnostics.

    Returns:
        Dictionary with per-component booleans, optional chunk/report
        counts, and an "overall_status" of "ready" or "not_ready".
    """
    vectorstore_ok = bool(
        self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
    )
    status = {
        "config_loaded": bool(self.config),
        "chunks_loaded": bool(self.chunks),
        "vectorstore_connected": vectorstore_ok,
        "components_initialized": bool(self.context_retriever and self.report_service),
    }

    if self.chunks:
        status["num_chunks"] = len(self.chunks)

    if self.report_service:
        status["available_sources"] = self.report_service.get_available_sources()
        status["available_reports"] = len(self.report_service.get_available_reports())

    required = (
        "config_loaded",
        "chunks_loaded",
        "vectorstore_connected",
        "components_initialized",
    )
    status["overall_status"] = (
        "ready" if all(status[key] for key in required) else "not_ready"
    )
    return status
def get_available_llm_providers(self) -> List[str]:
    """Return the lowercase names of LLM providers that have an entry
    under the 'reader' section of the config, in fixed priority order."""
    known = (
        "MISTRAL",
        "OPENAI",
        "OLLAMA",
        "INF_PROVIDERS",
        "NVIDIA",
        "DEDICATED",
        "OPENROUTER",
    )
    reader_config = self.config.get("reader", {})
    return [name.lower() for name in known if name in reader_config]
src/reporting/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report metadata and utilities."""
|
| 2 |
+
|
| 3 |
+
from .metadata import get_report_metadata, get_available_sources
|
| 4 |
+
from .service import ReportService
|
| 5 |
+
|
| 6 |
+
__all__ = ["get_report_metadata", "get_available_sources", "ReportService"]
|
src/reporting/feedback_schema.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Schema for RAG Chatbot
|
| 3 |
+
|
| 4 |
+
This module defines dataclasses for feedback data structures
|
| 5 |
+
and provides Snowflake schema generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, asdict, field
|
| 9 |
+
from typing import List, Optional, Dict, Any, Union
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class RetrievedDocument:
    """One document returned by a retrieval call, with its ranking score."""

    doc_id: str               # unique identifier of the retrieved chunk
    filename: str             # source file the chunk came from
    page: int                 # page number within the source file
    score: float              # retrieval / similarity score
    content: str              # raw text of the retrieved chunk
    metadata: Dict[str, Any]  # arbitrary extra metadata attached to the chunk
@dataclass
class RetrievalEntry:
    """Metadata captured for a single retrieval operation."""

    rag_query: str                                    # (expanded) query sent to the retriever
    documents_retrieved: List[RetrievedDocument]      # typed document records (may be empty)
    conversation_length: int                          # messages in the conversation at retrieval time
    filters_applied: Optional[Dict[str, Any]] = None  # filters used for this retrieval, if any
    timestamp: Optional[float] = None                 # unix timestamp of the retrieval
    _raw_data: Optional[Dict[str, Any]] = None        # original payload, preserved verbatim for serialization
@dataclass
class UserFeedback:
    """User feedback submission data.

    Mirrors one row of the Snowflake ``user_feedback`` table; nested
    retrieval history is carried in ``retrieved_data`` and serialized to
    a VARIANT column.
    """
    feedback_id: str
    open_ended_feedback: Optional[str]
    score: int
    is_feedback_about_last_retrieval: bool
    retrieved_data: List[RetrievalEntry]
    conversation_id: str
    timestamp: float
    message_count: int
    has_retrievals: bool
    retrieval_count: int
    user_query: Optional[str] = None
    bot_response: Optional[str] = None
    # ISO-8601 creation time, stamped at construction.
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary with nested data structures.

        ``asdict`` already deep-converts the nested dataclasses; the
        override below re-serializes ``retrieved_data`` so that entries
        carrying a preserved raw payload are emitted verbatim.
        """
        result = asdict(self)
        # Handle nested objects
        if self.retrieved_data:
            result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
        return result

    def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
        """Serialize retrieval entry to dict, preferring the preserved
        original payload (``_raw_data``) when one was captured."""
        # If raw data exists, use it (it's already properly formatted)
        if hasattr(entry, '_raw_data') and entry._raw_data:
            return entry._raw_data

        # Otherwise, serialize the dataclass
        result = asdict(entry)
        if entry.documents_retrieved:
            result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
        return result

    def to_snowflake_schema(self) -> Dict[str, Any]:
        """Generate Snowflake schema for this dataclass.

        NOTE: does not read ``self``; the classmethod below exploits that
        by calling it with ``None``.
        """
        schema = {
            "feedback_id": "VARCHAR(255)",
            "open_ended_feedback": "VARCHAR(16777216)",  # Large text
            "score": "INTEGER",
            "is_feedback_about_last_retrieval": "BOOLEAN",
            "conversation_id": "VARCHAR(255)",
            "timestamp": "NUMBER(20, 0)",
            "message_count": "INTEGER",
            "has_retrievals": "BOOLEAN",
            "retrieval_count": "INTEGER",
            "user_query": "VARCHAR(16777216)",
            "bot_response": "VARCHAR(16777216)",
            "created_at": "TIMESTAMP_NTZ",
            "retrieved_data": "VARIANT",  # Array of retrieval entries
            # retrieved_data structure:
            # [
            #   {
            #     "rag_query": "...",
            #     "conversation_length": 5,
            #     "timestamp": 1234567890,
            #     "docs_retrieved": [
            #       {"filename": "...", "page": 14, "score": 0.95, ...},
            #       ...
            #     ]
            #   },
            #   ...
            # ]
        }
        return schema

    @classmethod
    def get_snowflake_create_table_sql(cls, table_name: str = "user_feedback") -> str:
        """Generate CREATE TABLE SQL for Snowflake.

        Columns other than feedback_id/score/timestamp are nullable.
        NOTE(review): ``cls.to_snowflake_schema(None)`` passes None as
        ``self`` — works only because the method ignores self; consider a
        @staticmethod. Also, standard Snowflake tables may not accept
        CREATE INDEX — confirm against the target warehouse.
        """
        schema = cls.to_snowflake_schema(None)

        columns = []
        for col_name, col_type in schema.items():
            nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
            columns.append(f"    {col_name} {col_type} {nullable}")

        # Build SQL string properly
        columns_str = ",\n".join(columns)

        sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
{columns_str},
    PRIMARY KEY (feedback_id)
);

-- Create index on timestamp for querying by time
CREATE INDEX IF NOT EXISTS idx_feedback_timestamp ON {table_name} (timestamp);

-- Create index on conversation_id for querying by conversation
CREATE INDEX IF NOT EXISTS idx_feedback_conversation ON {table_name} (conversation_id);

-- Create index on score for feedback analysis
CREATE INDEX IF NOT EXISTS idx_feedback_score ON {table_name} (score);
"""
        return sql
# Snowflake variant schema for retrieved_data array.
# Documents the expected shape of each entry stored in the
# user_feedback.retrieved_data VARIANT column (reference only; not
# enforced by Snowflake).
RETRIEVAL_ENTRY_SCHEMA = {
    "rag_query": "VARCHAR",
    "documents_retrieved": "ARRAY",  # Array of document objects (see DOCUMENT_SCHEMA)
    "conversation_length": "INTEGER",
    "filters_applied": "OBJECT",
    "timestamp": "NUMBER"
}

# Shape of each element inside documents_retrieved (reference only).
DOCUMENT_SCHEMA = {
    "doc_id": "VARCHAR",
    "filename": "VARCHAR",
    "page": "INTEGER",
    "score": "DOUBLE",
    "content": "VARCHAR(16777216)",
    "metadata": "OBJECT"
}
def generate_snowflake_schema_sql() -> str:
    """Generate the complete Snowflake schema SQL for the feedback system.

    Delegates to the UserFeedback dataclass, which owns the column layout.
    """
    table_sql = UserFeedback.get_snowflake_create_table_sql("user_feedback")
    return table_sql
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
    """Create a UserFeedback instance from a plain dictionary.

    Required keys (KeyError if absent): score, is_feedback_about_last_retrieval,
    conversation_id, timestamp, message_count, has_retrievals, retrieval_count.
    Optional keys fall back to None (or a derived feedback_id).

    Entries in ``data["retrieved_data"]`` — as produced by the
    rag_retrieval_history structure (conversation_up_to, rag_query_expansion,
    docs_retrieved) — are mapped onto RetrievalEntry objects on a best-effort
    basis; entries that cannot be mapped are skipped silently.
    """
    retrieved_data = []
    for entry_dict in data.get("retrieved_data") or []:
        try:
            # Map the history structure onto the RetrievalEntry shape.
            entry = RetrievalEntry(
                rag_query=entry_dict.get("rag_query_expansion", ""),
                documents_retrieved=[],  # not re-typed here; raw payload kept below
                conversation_length=len(entry_dict.get("conversation_up_to", [])),
                filters_applied=None,
                timestamp=entry_dict.get("timestamp", None)
            )
            # Preserve the original payload so nothing is lost in the mapping.
            # NOTE(review): this will raise if RetrievalEntry ever declares
            # __slots__ — confirm it is a plain (non-slots) dataclass.
            entry._raw_data = entry_dict
            retrieved_data.append(entry)
        except Exception:
            # Best-effort mapping: skip entries that do not fit the schema.
            continue

    return UserFeedback(
        feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
        open_ended_feedback=data.get("open_ended_feedback"),
        score=data["score"],
        is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
        retrieved_data=retrieved_data,
        conversation_id=data["conversation_id"],
        timestamp=data["timestamp"],
        message_count=data["message_count"],
        has_retrievals=data["has_retrievals"],
        retrieval_count=data["retrieval_count"],
        user_query=data.get("user_query"),
        bot_response=data.get("bot_response")
    )
|
| 196 |
+
|
src/reporting/metadata.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report metadata management."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Any, Set
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_report_metadata(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Extract metadata from chunks.

    Collects the distinct sources, filenames and years found across the
    chunk metadata, plus the total chunk count.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        Dictionary with report metadata (empty dict for empty input)
    """
    if not chunks:
        return {}

    # One bucket per metadata key we aggregate.
    buckets = {"source": set(), "filename": set(), "year": set()}
    for chunk in chunks:
        meta = chunk.get("metadata", {})
        for key, bucket in buckets.items():
            if key in meta:
                bucket.add(meta[key])

    return {
        "sources": sorted(buckets["source"]),
        "filenames": sorted(buckets["filename"]),
        "years": sorted(buckets["year"]),
        "total_chunks": len(chunks),
    }
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_available_sources() -> List[str]:
    """
    Get list of available report sources (legacy compatibility).

    Returns:
        List of source categories
    """
    # Static fallback list; originally sourced from auditqa_old.reports.
    categories = (
        "Consolidated",
        "Ministry, Department, Agency and Projects",
        "Local Government",
        "Value for Money",
        "Thematic",
        "Hospital",
        "Project",
    )
    return list(categories)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_source_subtypes() -> Dict[str, List[str]]:
    """
    Get mapping of sources to their subtypes (placeholder).

    Returns:
        Dictionary mapping sources to subtypes
    """
    # Placeholder data — the real mapping originally lived in
    # auditqa_old.reports.new_files.
    subtype_pairs = [
        ("Consolidated", ["Annual Consolidated OAG 2024", "Annual Consolidated OAG 2023"]),
        ("Local Government", ["District Reports", "Municipal Reports"]),
        ("Ministry, Department, Agency and Projects", ["Ministry Reports", "Agency Reports"]),
        ("Value for Money", ["VFM Reports 2024", "VFM Reports 2023"]),
        ("Thematic", ["Thematic Reports 2024", "Thematic Reports 2023"]),
        ("Hospital", ["Hospital Reports 2024", "Hospital Reports 2023"]),
        ("Project", ["Project Reports 2024", "Project Reports 2023"]),
    ]
    return dict(subtype_pairs)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def validate_report_filters(
    reports: List[str] = None,
    sources: str = None,
    subtype: List[str] = None,
    available_metadata: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Validate report filter parameters against the available metadata.

    Args:
        reports: List of specific report filenames
        sources: Source category
        subtype: List of subtypes
        available_metadata: Available metadata for validation

    Returns:
        Dict with "valid" (bool), "warnings" (list) and "errors" (list).
        An unknown source is an error; unknown reports/subtypes only warn.
    """
    result = {"valid": True, "warnings": [], "errors": []}

    # Without metadata we cannot validate anything — warn and accept.
    if not available_metadata:
        result["warnings"].append("No metadata available for validation")
        return result

    known_sources = available_metadata.get("sources", [])
    known_filenames = available_metadata.get("filenames", [])

    # Source must exist exactly; a miss invalidates the filter set.
    if sources and sources not in known_sources:
        result["errors"].append(f"Source '{sources}' not found in available sources")
        result["valid"] = False

    # Unknown report names are soft failures.
    for report in reports or []:
        if report not in known_filenames:
            result["warnings"].append(f"Report '{report}' not found in available reports")

    # Unknown subtypes are soft failures too.
    for sub in subtype or []:
        if sub not in known_filenames:
            result["warnings"].append(f"Subtype '{sub}' not found in available reports")

    return result
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_report_statistics(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Get statistics about reports in chunks.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        Dictionary with per-source and per-year chunk counts plus total and
        average content lengths; empty dict for empty input.
    """
    if not chunks:
        return {}

    source_counts: Dict[str, int] = {}
    year_counts: Dict[str, int] = {}
    content_total = 0

    for chunk in chunks:
        content_total += len(chunk.get("content", ""))
        meta = chunk.get("metadata", {})

        # Missing metadata keys are bucketed under "Unknown".
        src = meta.get("source", "Unknown")
        source_counts[src] = source_counts.get(src, 0) + 1

        yr = meta.get("year", "Unknown")
        year_counts[yr] = year_counts.get(yr, 0) + 1

    return {
        "total_chunks": len(chunks),
        "sources": source_counts,
        "years": year_counts,
        "avg_chunk_length": content_total / len(chunks),
        "total_content_length": content_total,
    }
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def filter_chunks_by_metadata(
    chunks: List[Dict[str, Any]],
    source_filter: str = None,
    filename_filter: List[str] = None,
    year_filter: List[str] = None
) -> List[Dict[str, Any]]:
    """
    Filter chunks by metadata criteria.

    Each filter is applied only when provided; the filters combine with AND.

    Args:
        chunks: List of chunk dictionaries
        source_filter: Source to keep (exact match)
        filename_filter: Filenames to keep
        year_filter: Years to keep

    Returns:
        Filtered list of chunks
    """
    def _meta(chunk):
        return chunk.get("metadata", {})

    kept = chunks
    if source_filter:
        kept = [c for c in kept if _meta(c).get("source") == source_filter]
    if filename_filter:
        kept = [c for c in kept if _meta(c).get("filename") in filename_filter]
    if year_filter:
        kept = [c for c in kept if _meta(c).get("year") in year_filter]
    return kept
|
src/reporting/service.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report service for managing report operations."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Any, Optional
|
| 4 |
+
from .metadata import get_report_metadata, get_available_sources, get_source_subtypes
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ReportService:
    """Service class for report operations.

    Wraps a list of chunk dictionaries and answers questions about which
    sources, reports and years they contain.
    """

    def __init__(self, chunks: List[Dict[str, Any]] = None):
        """
        Initialize report service.

        Args:
            chunks: List of chunk dictionaries
        """
        self.chunks = chunks or []
        # Pre-computed summary (sources/filenames/years) of the chunk set.
        self.metadata = get_report_metadata(self.chunks) if self.chunks else {}

    def get_available_sources(self) -> List[str]:
        """Get available report sources."""
        # Prefer sources derived from actual data; fall back to the static list.
        if self.metadata:
            return self.metadata.get("sources", [])
        return get_available_sources()

    def get_available_reports(self) -> List[str]:
        """Get available report filenames."""
        return self.metadata.get("filenames", [])

    def get_source_subtypes(self) -> Dict[str, List[str]]:
        """Get source to subtype mapping."""
        # Placeholder mapping; a full implementation would derive this from data.
        return get_source_subtypes()

    def _values_for_source(self, source: str, field: str) -> List[str]:
        """Collect sorted, distinct truthy metadata[field] values for a source."""
        found = set()
        for chunk in self.chunks:
            meta = chunk.get("metadata", {})
            if meta.get("source") == source:
                value = meta.get(field)
                if value:
                    found.add(value)
        return sorted(found)

    def get_reports_by_source(self, source: str) -> List[str]:
        """
        Get reports filtered by source.

        Args:
            source: Source category

        Returns:
            Sorted list of report filenames
        """
        return self._values_for_source(source, "filename")

    def get_years_by_source(self, source: str) -> List[str]:
        """
        Get years available for a specific source.

        Args:
            source: Source category

        Returns:
            Sorted list of years
        """
        return self._values_for_source(source, "year")

    def search_reports(self, query: str) -> List[str]:
        """
        Case-insensitively search report filenames for a substring.

        Args:
            query: Search query

        Returns:
            Sorted list of matching report filenames
        """
        needle = query.lower()
        hits = {
            chunk.get("metadata", {}).get("filename", "")
            for chunk in self.chunks
            if needle in chunk.get("metadata", {}).get("filename", "").lower()
        }
        return sorted(hits)

    def get_report_info(self, filename: str) -> Dict[str, Any]:
        """
        Get aggregate information about a specific report.

        Args:
            filename: Report filename

        Returns:
            Dict with filename, chunk_count, sources, years and
            total_content_length; empty dict if the service holds no chunks.
        """
        if not self.chunks:
            return {}

        info = {
            "filename": filename,
            "chunk_count": 0,
            "sources": set(),
            "years": set(),
            "total_content_length": 0,
        }

        for chunk in self.chunks:
            meta = chunk.get("metadata", {})
            if meta.get("filename") != filename:
                continue
            info["chunk_count"] += 1
            info["total_content_length"] += len(chunk.get("content", ""))
            if "source" in meta:
                info["sources"].add(meta["source"])
            if "year" in meta:
                info["years"].add(meta["year"])

        # JSON-friendly output: replace the accumulator sets with lists.
        info["sources"] = list(info["sources"])
        info["years"] = list(info["years"])
        return info
|
src/reporting/snowflake_connector.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Snowflake Connector for Feedback System
|
| 3 |
+
|
| 4 |
+
This module handles inserting user feedback into Snowflake.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from src.reporting.feedback_schema import UserFeedback
|
| 12 |
+
|
| 13 |
+
# Optional dependency: the app must still import (and degrade gracefully)
# when snowflake-connector-python is absent, so gate it behind a flag.
# Try to import snowflake connector
try:
    import snowflake.connector
    SNOWFLAKE_AVAILABLE = True  # checked before any connection attempt
except ImportError:
    SNOWFLAKE_AVAILABLE = False
    logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")

# Configure logging
# NOTE(review): basicConfig() here configures the ROOT logger as an import-time
# side effect of this library module — consider moving it to the application
# entry point so the host app's logging setup is not overridden.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)  # module-level logger (standard convention)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SnowflakeFeedbackConnector:
    """Connector for inserting feedback into Snowflake.

    Credentials are held in memory; a single connection is opened with
    connect() and closed with disconnect(). The class is also usable as a
    context manager (connects on entry, disconnects on exit).
    """

    def __init__(
        self,
        user: str,
        password: str,
        account: str,
        warehouse: str,
        database: str = "SNOWFLAKE_LEARNING",
        schema: str = "PUBLIC"
    ):
        """Store connection parameters; no connection is opened yet."""
        self.user = user
        self.password = password
        self.account = account
        self.warehouse = warehouse
        self.database = database
        self.schema = schema
        self._connection = None  # live connection, or None when disconnected

    def connect(self):
        """Establish Snowflake connection.

        Raises:
            ImportError: if snowflake-connector-python is not installed.
            Exception: re-raises whatever the driver raises on failure.
        """
        if not SNOWFLAKE_AVAILABLE:
            raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")

        logger.info("=" * 80)
        logger.info("🔌 SNOWFLAKE CONNECTION: Attempting to connect...")
        logger.info(f" - Account: {self.account}")
        logger.info(f" - Warehouse: {self.warehouse}")
        logger.info(f" - Database: {self.database}")
        logger.info(f" - Schema: {self.schema}")
        logger.info(f" - User: {self.user}")

        try:
            self._connection = snowflake.connector.connect(
                user=self.user,
                password=self.password,
                account=self.account,
                warehouse=self.warehouse
                # Don't set database/schema in connection - we'll do it per query
            )
            logger.info("✅ SNOWFLAKE CONNECTION: Successfully connected")
            logger.info("=" * 80)
            print(f"✅ Connected to Snowflake: {self.database}.{self.schema}")
        except Exception as e:
            logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
            logger.error("=" * 80)
            print(f"❌ Failed to connect to Snowflake: {e}")
            raise

    def disconnect(self):
        """Close Snowflake connection (idempotent)."""
        if self._connection:
            self._connection.close()
            # FIX: drop the stale handle so a second disconnect() is a no-op
            # and insert_feedback() cannot accidentally use a closed connection.
            self._connection = None
            print("✅ Disconnected from Snowflake")

    def insert_feedback(self, feedback: UserFeedback) -> bool:
        """Insert a single feedback record into Snowflake.

        Returns:
            True on success; False on validation failure or SQL error.

        Raises:
            RuntimeError: if called before connect().
        """
        logger.info("=" * 80)
        logger.info("🔄 SNOWFLAKE INSERT: Starting feedback insertion process")
        logger.info(f"📝 Feedback ID: {feedback.feedback_id}")

        if not self._connection:
            logger.error("❌ Not connected to Snowflake. Call connect() first.")
            raise RuntimeError("Not connected to Snowflake. Call connect() first.")

        cursor = None
        try:
            logger.info("📊 VALIDATION: Validating feedback data structure...")

            # Validate feedback object before touching the database.
            validation_errors = []
            if not feedback.feedback_id:
                validation_errors.append("Missing feedback_id")
            if feedback.score is None:
                validation_errors.append("Missing score")
            if feedback.timestamp is None:
                validation_errors.append("Missing timestamp")

            if validation_errors:
                logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
                return False
            logger.info("✅ VALIDATION PASSED: All required fields present")

            logger.info("📋 Data Summary:")
            logger.info(f" - Feedback ID: {feedback.feedback_id}")
            logger.info(f" - Score: {feedback.score}")
            logger.info(f" - Conversation ID: {feedback.conversation_id}")
            logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
            logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
            logger.info(f" - Message Count: {feedback.message_count}")
            logger.info(f" - Timestamp: {feedback.timestamp}")

            cursor = self._connection.cursor()
            logger.info("✅ SNOWFLAKE CONNECTION: Cursor created")

            # Set database and schema context (connection was opened without them).
            logger.info(f"🔧 SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
            try:
                cursor.execute(f'USE DATABASE "{self.database}"')
                cursor.execute(f'USE SCHEMA "{self.schema}"')
                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
                current_db, current_schema = cursor.fetchone()
                logger.info(f"✅ Current context verified: Database={current_db}, Schema={current_schema}")
            except Exception as e:
                logger.error(f"❌ Could not set context: {e}")
                raise

            # Prepare retrieved_data: normalise to a Python object, then to a
            # JSON string suitable for the TEXT column (NULL when empty).
            logger.info("🔧 DATA PREPARATION: Preparing retrieved_data...")
            retrieved_data_raw = feedback.to_dict()['retrieved_data']

            logger.info(f" - Retrieved data type (raw): {type(retrieved_data_raw).__name__}")
            logger.info(f" - Retrieved data: {repr(retrieved_data_raw)[:200]}")

            # If retrieved_data is already a string (from UI), parse it
            if isinstance(retrieved_data_raw, str):
                logger.info(" - Parsing string to Python object")
                retrieved_data = json.loads(retrieved_data_raw)
            elif retrieved_data_raw is None:
                retrieved_data = None
            else:
                # It's already a Python object (list/dict)
                logger.info(" - Data is already a Python object")
                retrieved_data = retrieved_data_raw

            logger.info(f" - Retrieved data size: {len(str(retrieved_data)) if retrieved_data else 0} characters")
            logger.info(f" - Retrieved data type: {type(retrieved_data).__name__}")

            if retrieved_data:
                retrieved_data_for_db = json.dumps(retrieved_data)
                logger.info(" - Converting to JSON string for TEXT column")
                logger.info(f" - JSON string length: {len(retrieved_data_for_db)}")
            else:
                logger.info(" - Retrieved data is None, using NULL")
                retrieved_data_for_db = None

            # Parameterised INSERT — never interpolate values into the SQL text.
            sql = """INSERT INTO user_feedback (
    feedback_id,
    open_ended_feedback,
    score,
    is_feedback_about_last_retrieval,
    conversation_id,
    timestamp,
    message_count,
    has_retrievals,
    retrieval_count,
    user_query,
    bot_response,
    created_at,
    retrieved_data
) VALUES (
    %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
    %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
    %(retrieval_count)s, %(user_query)s, %(bot_response)s, %(created_at)s,
    %(retrieved_data)s
)"""

            logger.info("📝 SQL PREPARATION: Building INSERT statement...")
            logger.info(" - Target table: user_feedback")
            logger.info(f" - Database: {self.database}")
            logger.info(f" - Schema: {self.schema}")

            params = {
                'feedback_id': feedback.feedback_id,
                'open_ended_feedback': feedback.open_ended_feedback,
                'score': feedback.score,
                'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
                'conversation_id': feedback.conversation_id,
                'timestamp': int(feedback.timestamp),
                'message_count': feedback.message_count,
                'has_retrievals': feedback.has_retrievals,
                'retrieval_count': feedback.retrieval_count,
                'user_query': feedback.user_query,
                'bot_response': feedback.bot_response,
                'created_at': feedback.created_at,
                'retrieved_data': retrieved_data_for_db
            }

            logger.info("🚀 SQL EXECUTION: Executing INSERT query...")
            cursor.execute(sql, params)

            logger.info("✅ SQL EXECUTION: Query executed successfully")
            logger.info(" - Rows affected: 1")
            logger.info(" - Status: SUCCESS")

            logger.info("✅ SNOWFLAKE INSERT: Feedback inserted successfully")
            logger.info(f"📝 Inserted feedback: {feedback.feedback_id}")
            logger.info("=" * 80)
            return True

        except Exception as e:
            # Check if it's a Snowflake error
            if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
                logger.error(f"❌ SQL EXECUTION ERROR: {e}")
                logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
                logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
            else:
                logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
                logger.error(f" - Error: {e}")
            logger.error("=" * 80)
            return False
        finally:
            # FIX: always release the cursor, including on error paths
            # (the original leaked it whenever an exception was raised).
            if cursor is not None:
                cursor.close()

    def __enter__(self):
        """Context manager entry"""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.disconnect()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
    """Create Snowflake connector from environment variables.

    Reads SNOWFLAKE_USER / SNOWFLAKE_PASSWORD / SNOWFLAKE_ACCOUNT /
    SNOWFLAKE_WAREHOUSE (all required) plus optional SNOWFLAKE_DATABASE and
    SNOWFLAKE_SCHEMA. Returns None when any required variable is missing.
    """
    # NOTE(review): this fallback ("SNOWFLAKE_LEARN") differs from the class
    # default ("SNOWFLAKE_LEARNING") — confirm which database name is intended.
    database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
    schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")

    required = {
        name: os.getenv(name)
        for name in ("SNOWFLAKE_USER", "SNOWFLAKE_PASSWORD", "SNOWFLAKE_ACCOUNT", "SNOWFLAKE_WAREHOUSE")
    }
    if not all(required.values()):
        print("⚠️ Snowflake credentials not found in environment variables")
        print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
        return None

    return SnowflakeFeedbackConnector(
        user=required["SNOWFLAKE_USER"],
        password=required["SNOWFLAKE_PASSWORD"],
        account=required["SNOWFLAKE_ACCOUNT"],
        warehouse=required["SNOWFLAKE_WAREHOUSE"],
        database=database,
        schema=schema
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def save_to_snowflake(feedback: UserFeedback) -> bool:
    """Helper function to save feedback to Snowflake.

    Connects, inserts the single feedback record, and always disconnects —
    even if the insert raises. Returns True only on a successful insert;
    missing credentials or any error yield False (this helper never raises).
    """
    logger.info("=" * 80)
    logger.info("🔵 SNOWFLAKE SAVE: Starting save process")
    logger.info(f"📝 Feedback ID: {feedback.feedback_id}")

    connector = get_snowflake_connector_from_env()

    if not connector:
        logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
        logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
        logger.info("=" * 80)
        return False

    try:
        logger.info("📡 SNOWFLAKE SAVE: Establishing connection...")
        connector.connect()
        logger.info("✅ SNOWFLAKE SAVE: Connection established")

        try:
            logger.info("📥 SNOWFLAKE SAVE: Attempting to insert feedback...")
            success = connector.insert_feedback(feedback)
        finally:
            # FIX: always release the connection, even if insert_feedback
            # raises unexpectedly (the original skipped disconnect() there).
            logger.info("🔌 SNOWFLAKE SAVE: Disconnecting...")
            connector.disconnect()

        if success:
            logger.info("✅ SNOWFLAKE SAVE: Successfully saved feedback")
        else:
            logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")

        logger.info("=" * 80)
        return success
    except Exception as e:
        logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
        logger.error(f" - Error: {e}")
        logger.info("=" * 80)
        return False
|
| 305 |
+
|
src/retrieval/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document retrieval and filtering utilities."""
|
| 2 |
+
|
| 3 |
+
from .filter import create_filter, FilterBuilder
|
| 4 |
+
from .context import ContextRetriever, get_context
|
| 5 |
+
from .hybrid import HybridRetriever, get_available_search_modes, get_search_mode_description
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"create_filter",
|
| 9 |
+
"FilterBuilder",
|
| 10 |
+
"ContextRetriever",
|
| 11 |
+
"get_context",
|
| 12 |
+
"HybridRetriever",
|
| 13 |
+
"get_available_search_modes",
|
| 14 |
+
"get_search_mode_description"
|
| 15 |
+
]
|
src/retrieval/colbert_cache.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ColBERT embeddings cache for test set documents.
|
| 3 |
+
Provides O(1) lookup for ColBERT embeddings during late interaction.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Optional, Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ColBERTCache:
    """Cache for ColBERT embeddings of test set documents.

    Loads a JSON file of uint8-quantized embeddings, reconstructs the
    float32 matrices once at startup, and serves them via O(1) dict lookup.
    Cache entries are keyed by the key stored in the JSON file (presumably
    the document text, given ``get_embedding``'s parameter name — TODO confirm
    against the precalculation script).
    """

    def __init__(self, cache_file: str = "test_set_colbert_cache.json",
                 cache_dir: str = "outputs/caches"):
        """Load the cache immediately.

        Args:
            cache_file: Name of the cache JSON file.
            cache_dir: Directory containing the cache file (generalized from
                the previously hard-coded "outputs/caches"; default preserves
                the old behavior).
        """
        self.cache_file = Path(cache_dir) / cache_file
        # Maps cache key -> reconstructed float32 embedding matrix.
        self.embeddings_cache: Dict[str, np.ndarray] = {}
        self._load_cache()

    def _load_cache(self) -> None:
        """Load and dequantize embeddings from the cache file.

        Missing or corrupt cache files are tolerated: the cache simply stays
        empty and a hint is printed.
        """
        if not self.cache_file.exists():
            print(f"⚠️ ColBERT cache not found: {self.cache_file}")
            print("💡 Run 'python precalculate_test_set_colbert.py' to create cache")
            return

        print(f"📂 Loading ColBERT cache from {self.cache_file}...")

        try:
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                cache_data = json.load(f)

            # Each entry stores a uint8-quantized flat embedding plus the
            # (min, max) range and original shape needed to reconstruct it.
            for doc_id, data in cache_data.items():
                embedding_min = data['min']
                embedding_max = data['max']
                quantized_embedding = np.array(data['embedding'], dtype=np.uint8)

                # Invert the linear quantization: q/255 * (max-min) + min
                reconstructed = (
                    (quantized_embedding.astype(np.float32) / 255.0)
                    * (embedding_max - embedding_min)
                    + embedding_min
                )
                self.embeddings_cache[doc_id] = reconstructed.reshape(data['shape'])

            print(f"✅ Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")

        except Exception as e:
            # A half-loaded cache is worse than an empty one; reset on error.
            print(f"❌ Error loading ColBERT cache: {e}")
            self.embeddings_cache = {}

    def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
        """Return the cached ColBERT embedding for a document, or None (O(1))."""
        return self.embeddings_cache.get(document_text)

    def has_embedding(self, document_text: str) -> bool:
        """Return True if an embedding is cached for the given document."""
        return document_text in self.embeddings_cache

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics (entry count, file path, file existence)."""
        return {
            'total_embeddings': len(self.embeddings_cache),
            'cache_file': str(self.cache_file),
            'cache_exists': self.cache_file.exists()
        }
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Module-level singleton, created lazily on first access.
_colbert_cache = None


def get_colbert_cache() -> ColBERTCache:
    """Return the shared ColBERTCache instance, building it on first call."""
    global _colbert_cache
    cache = _colbert_cache
    if cache is None:
        cache = ColBERTCache()
        _colbert_cache = cache
    return cache
|
src/retrieval/context.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Context retrieval with reranking capabilities."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Optional, Tuple, Dict, Any
|
| 5 |
+
from langchain.schema import Document
|
| 6 |
+
from langchain_community.vectorstores import Qdrant
|
| 7 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 8 |
+
from sentence_transformers import CrossEncoder
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from qdrant_client.http import models as rest
|
| 12 |
+
import traceback
|
| 13 |
+
|
| 14 |
+
from .filter import create_filter
|
| 15 |
+
|
| 16 |
+
class ContextRetriever:
|
| 17 |
+
"""
|
| 18 |
+
Context retriever for hybrid search with optional filtering and reranking.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, vectorstore: Qdrant, config: dict = None):
|
| 22 |
+
"""
|
| 23 |
+
Initialize the context retriever.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
vectorstore: Qdrant vector store instance
|
| 27 |
+
config: Configuration dictionary
|
| 28 |
+
"""
|
| 29 |
+
self.vectorstore = vectorstore
|
| 30 |
+
self.config = config or {}
|
| 31 |
+
self.reranker = None
|
| 32 |
+
|
| 33 |
+
# BM25 attributes
|
| 34 |
+
self.bm25_vectorizer = None
|
| 35 |
+
self.bm25_matrix = None
|
| 36 |
+
self.bm25_documents = None
|
| 37 |
+
|
| 38 |
+
# Initialize reranker if available
|
| 39 |
+
# Try to get reranker model from different config paths
|
| 40 |
+
self.reranker_model_name = (
|
| 41 |
+
config.get('retrieval', {}).get('reranker_model') or
|
| 42 |
+
config.get('ranker', {}).get('model') or
|
| 43 |
+
config.get('reranker_model') or
|
| 44 |
+
'BAAI/bge-reranker-v2-m3'
|
| 45 |
+
)
|
| 46 |
+
self.reranker_type = self._detect_reranker_type(self.reranker_model_name)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
if self.reranker_type == 'colbert':
|
| 50 |
+
from colbert.infra import Run, ColBERTConfig
|
| 51 |
+
from colbert.modeling.checkpoint import Checkpoint
|
| 52 |
+
# ColBERT uses late interaction - different implementation needed
|
| 53 |
+
print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})")
|
| 54 |
+
print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)")
|
| 55 |
+
|
| 56 |
+
# Create ColBERT config for CPU mode
|
| 57 |
+
colbert_config = ColBERTConfig(
|
| 58 |
+
doc_maxlen=300,
|
| 59 |
+
query_maxlen=32,
|
| 60 |
+
nbits=2,
|
| 61 |
+
kmeans_niters=4,
|
| 62 |
+
root="./colbert_data"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Load checkpoint (e.g. "colbert-ir/colbertv2.0")
|
| 66 |
+
self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
|
| 67 |
+
self.colbert_model = self.colbert_checkpoint.model
|
| 68 |
+
self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
|
| 69 |
+
self.reranker = self._colbert_rerank # attach wrapper function
|
| 70 |
+
print(f"✅ COLBERT: Model and tokenizer loaded successfully")
|
| 71 |
+
|
| 72 |
+
else:
|
| 73 |
+
# Standard CrossEncoder for BGE and other models
|
| 74 |
+
from sentence_transformers import CrossEncoder
|
| 75 |
+
self.reranker = CrossEncoder(self.reranker_model_name)
|
| 76 |
+
print(f"✅ RERANKER: Initialized {self.reranker_model_name}")
|
| 77 |
+
print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"⚠️ Reranker initialization failed: {e}")
|
| 80 |
+
self.reranker = None
|
| 81 |
+
|
| 82 |
+
def _detect_reranker_type(self, model_name: str) -> str:
|
| 83 |
+
"""
|
| 84 |
+
Detect the type of reranker based on model name.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
model_name: Name of the reranker model
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
'colbert' for ColBERT models, 'crossencoder' for others
|
| 91 |
+
"""
|
| 92 |
+
model_name_lower = model_name.lower()
|
| 93 |
+
|
| 94 |
+
# ColBERT model patterns
|
| 95 |
+
colbert_patterns = [
|
| 96 |
+
'colbert',
|
| 97 |
+
'colbert-ir',
|
| 98 |
+
'colbertv2',
|
| 99 |
+
'colbert-v2'
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
for pattern in colbert_patterns:
|
| 103 |
+
if pattern in model_name_lower:
|
| 104 |
+
return 'colbert'
|
| 105 |
+
|
| 106 |
+
# Default to cross-encoder for BGE and other models
|
| 107 |
+
return 'crossencoder'
|
| 108 |
+
|
| 109 |
+
    def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]:
        """
        Perform similarity search and fetch ColBERT embeddings for documents.

        Runs a normal dense similarity search first, then re-fetches the same
        points from Qdrant with their full payload so that pre-calculated
        ColBERT embeddings ('colbert_embedding' etc.) end up in each
        Document's metadata for late-interaction reranking.

        Args:
            query: Search query
            k: Number of documents to retrieve
            **kwargs: Additional search parameters (only 'filter' is used)

        Returns:
            List of (Document, score) tuples with ColBERT embeddings in
            metadata. On any error, falls back to the plain similarity search
            result (no ColBERT metadata).
        """
        try:
            print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings")

            # Use the vectorstore's similarity_search_with_score method instead of
            # the raw client for the initial search — it handles filters properly.
            if 'filter' in kwargs and kwargs['filter']:
                result = self.vectorstore.similarity_search_with_score(
                    query,
                    k=k,
                    filter=kwargs['filter']
                )
            else:
                result = self.vectorstore.similarity_search_with_score(query, k=k)

            # Normalize the result into parallel documents/scores lists; the
            # method may return a (docs, scores) tuple or a list of
            # (Document, score) pairs depending on the wrapper version.
            if isinstance(result, tuple) and len(result) == 2:
                documents, scores = result
            elif isinstance(result, list):
                documents = []
                scores = []
                for item in result:
                    if isinstance(item, tuple) and len(item) == 2:
                        doc, score = item
                        documents.append(doc)
                        scores.append(score)
                    else:
                        documents.append(item)
                        scores.append(0.0)
            else:
                documents = []
                scores = []

            # Fetch the ColBERT payload fields via the Qdrant client directly,
            # since the LangChain wrapper does not expose arbitrary payload keys.
            from qdrant_client.http import models as rest

            collection_name = self.vectorstore.collection_name

            # Build the list of point IDs to re-fetch.
            # NOTE(review): when a document has no 'id'/'_id' in metadata, an
            # md5 hash of the content is used as the ID. Qdrant will only
            # return points whose stored IDs actually equal these values —
            # if the collection was not indexed with md5-of-content IDs,
            # such documents are silently dropped here. Verify against the
            # ingestion pipeline.
            doc_ids = []
            for doc in documents:
                # Extract ID from document metadata or use page_content hash as fallback
                doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
                if not doc_id:
                    # Use a hash of the content as ID
                    import hashlib
                    doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
                doc_ids.append(doc_id)

            # Fetch documents with ColBERT embeddings from Qdrant
            search_result = self.vectorstore.client.retrieve(
                collection_name=collection_name,
                ids=doc_ids,
                with_payload=True,
                with_vectors=False
            )

            # Convert results to Document objects with ColBERT embeddings
            enhanced_documents = []
            enhanced_scores = []

            # Map each computed doc ID back to its original similarity score,
            # so scores survive the round-trip through Qdrant.
            doc_id_to_score = {}
            for i, doc in enumerate(documents):
                doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
                if not doc_id:
                    import hashlib
                    doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
                doc_id_to_score[doc_id] = scores[i]

            for point in search_result:
                # Extract payload
                payload = point.payload

                # Get the original score for this document.
                # NOTE(review): point.id is stringified for the lookup; if the
                # collection stores integer/UUID IDs while the map keys were
                # md5 hashes, this falls back to 0.0 — confirm ID scheme.
                doc_id = str(point.id)
                original_score = doc_id_to_score.get(doc_id, 0.0)

                # Create Document object with the ColBERT payload fields
                # surfaced into metadata for the reranker.
                doc = Document(
                    page_content=payload.get('page_content', ''),
                    metadata={
                        **payload.get('metadata', {}),
                        'colbert_embedding': payload.get('colbert_embedding'),
                        'colbert_model': payload.get('colbert_model'),
                        'colbert_calculated_at': payload.get('colbert_calculated_at')
                    }
                )

                enhanced_documents.append(doc)
                enhanced_scores.append(original_score)

            print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings")

            return list(zip(enhanced_documents, enhanced_scores))

        except Exception as e:
            print(f"❌ COLBERT RETRIEVAL ERROR: {e}")
            print(f"❌ Falling back to regular similarity search")

            # Fallback to regular search - handle filter parameter correctly
            if 'filter' in kwargs and kwargs['filter']:
                return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter'])
            else:
                return self.vectorstore.similarity_search_with_score(query, k=k)
|
| 228 |
+
|
| 229 |
+
    def retrieve_context(
        self,
        query: str,
        k: int = 5,
        reports: Optional[List[str]] = None,
        sources: Optional[List[str]] = None,
        subtype: Optional[str] = None,
        year: Optional[str] = None,
        district: Optional[List[str]] = None,
        filenames: Optional[List[str]] = None,
        use_reranking: bool = False,
        qdrant_filter: Optional[rest.Filter] = None
    ) -> List[Document]:
        """
        Retrieve context documents using hybrid search with optional filtering and reranking.

        Pipeline: build/accept a Qdrant filter -> dense similarity search
        (ColBERT-aware when a ColBERT reranker is active) -> optional
        filter-less fallback when too few documents are found -> optional
        reranking -> truncation to ``k`` and metadata annotation.

        Args:
            query: User query
            k: Number of documents to return
            reports: List of report names to filter by
            sources: List of sources to filter by
            subtype: Document subtype to filter by
            year: Year to filter by
            district: List of districts to filter by
            filenames: List of filenames to filter by
            use_reranking: Whether to apply reranking
            qdrant_filter: Pre-built Qdrant filter; takes precedence over the
                individual filter parameters above

        Returns:
            List of retrieved documents (empty list on any error)
        """
        try:
            # Determine how many documents to retrieve.
            # NOTE(review): the over-fetch multiplier for reranking is commented
            # out, so reranking currently operates on exactly k candidates.
            retrieve_k = k #* 3 if use_reranking else k  # Retrieve more for reranking

            # Build search kwargs
            search_kwargs = {}

            # Use qdrant_filter if provided (this takes precedence)
            if qdrant_filter:
                search_kwargs = {"filter": qdrant_filter}
                print(f"✅ FILTERS APPLIED: Using inferred Qdrant filter")
            else:
                # Build filter from individual parameters
                filter_obj = create_filter(
                    reports=reports,
                    sources=sources,
                    subtype=subtype,
                    year=year,
                    district=district,
                    filenames=filenames
                )

                if filter_obj:
                    search_kwargs = {"filter": filter_obj}
                    print(f"✅ FILTERS APPLIED: Using built filter")
                else:
                    search_kwargs = {}
                    print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")

            # Perform vector search
            try:
                # ColBERT reranking needs the pre-calculated token embeddings
                # from the payload, so use the ColBERT-aware search path.
                if use_reranking and self.reranker_type == 'colbert':
                    result = self._similarity_search_with_colbert_embeddings(
                        query,
                        k=retrieve_k,
                        **search_kwargs
                    )
                else:
                    result = self.vectorstore.similarity_search_with_score(
                        query,
                        k=retrieve_k,
                        **search_kwargs
                    )

                # Normalize the result into parallel documents/scores lists
                # (wrapper may return a tuple or a list of pairs).
                if isinstance(result, tuple) and len(result) == 2:
                    documents, scores = result
                elif isinstance(result, list) and len(result) > 0:
                    # Handle case where result is a list of (Document, score) tuples
                    documents = []
                    scores = []
                    for item in result:
                        if isinstance(item, tuple) and len(item) == 2:
                            doc, score = item
                            documents.append(doc)
                            scores.append(score)
                        else:
                            # Handle case where item is just a Document
                            documents.append(item)
                            scores.append(0.0)  # Default score
                else:
                    documents = []
                    scores = []

                print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")

                # If we got fewer documents than requested, retry without filters.
                # NOTE(review): this silently widens the search beyond the
                # user's filters — confirm that is intended behavior.
                if len(documents) < retrieve_k and search_kwargs.get('filter'):
                    print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
                    try:
                        result_no_filter = self.vectorstore.similarity_search_with_score(
                            query,
                            k=retrieve_k
                        )

                        # Same result-shape normalization as above.
                        if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
                            documents_no_filter, scores_no_filter = result_no_filter
                        elif isinstance(result_no_filter, list):
                            documents_no_filter = []
                            scores_no_filter = []
                            for item in result_no_filter:
                                if isinstance(item, tuple) and len(item) == 2:
                                    doc, score = item
                                    documents_no_filter.append(doc)
                                    scores_no_filter.append(score)
                                else:
                                    documents_no_filter.append(item)
                                    scores_no_filter.append(0.0)
                        else:
                            documents_no_filter = []
                            scores_no_filter = []

                        # Keep the unfiltered result only if it is strictly larger.
                        if len(documents_no_filter) > len(documents):
                            print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
                            documents = documents_no_filter
                            scores = scores_no_filter
                    except Exception as e:
                        print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")

            except Exception as e:
                print(f"❌ RETRIEVAL ERROR: {str(e)}")
                return []

            # Apply reranking if enabled
            reranking_applied = False
            if use_reranking and len(documents) > 1:
                print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
                try:
                    # Keep originals so a failed rerank can restore the order.
                    original_docs = documents.copy()
                    original_scores = scores.copy()

                    # Apply reranking
                    # print(f"🔍 ORIGINAL DOCS: {documents[0]}")
                    reranked_docs = self._apply_reranking(query, documents, scores)
                    # print(f"🔍 RERANKED DOCS: {reranked_docs[0]}")
                    reranking_applied = len(reranked_docs) > 0

                    if reranking_applied:
                        print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
                        documents = reranked_docs
                        # Update scores to reflect reranking
                        # scores = [0.0] * len(documents)  # Reranked scores are not directly comparable
                    else:
                        print(f"⚠️ RERANKING FAILED: Using original order")
                        documents = original_docs
                        scores = original_scores
                    # NOTE(review): this early return skips the final
                    # documents[:k] truncation and metadata annotation below —
                    # reranked results are returned at retrieved length.
                    return documents

                except Exception as e:
                    print(f"❌ RERANKING ERROR: {str(e)}")
                    print(f"⚠️ RERANKING FAILED: Using original order")
                    reranking_applied = False
            elif use_reranking and len(documents) <= 1:
                print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
                # NOTE(review): this inner check is always true inside this
                # elif branch (use_reranking was already required), so this
                # path always annotates and returns early.
                if use_reranking:
                    print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")
                    # Store original scores in metadata
                    for i, (doc, score) in enumerate(zip(documents, scores)):
                        doc.metadata['original_score'] = float(score)
                        doc.metadata['reranking_applied'] = False
                    return documents
            else:
                print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")

            # Limit to requested number of documents
            documents = documents[:k]
            scores = scores[:k] if scores else [0.0] * len(documents)

            # Add retrieval metadata to documents.
            # NOTE(review): 'reranker_model' is hard-coded here instead of
            # using self.reranker_model_name — confirm whether that is stale.
            for i, (doc, score) in enumerate(zip(documents, scores)):
                if hasattr(doc, 'metadata'):
                    doc.metadata.update({
                        'reranking_applied': reranking_applied,
                        'reranker_model': 'BAAI/bge-reranker-v2-m3' if reranking_applied else None,
                        'original_rank': i + 1,
                        'final_rank': i + 1,
                        'original_score': float(score) if score is not None else 0.0
                    })

            return documents

        except Exception as e:
            print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
            return []
|
| 423 |
+
|
| 424 |
+
def _apply_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 425 |
+
"""
|
| 426 |
+
Apply reranking to documents using the appropriate reranker.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
query: User query
|
| 430 |
+
documents: List of documents to rerank
|
| 431 |
+
scores: Original scores
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
Reranked list of documents
|
| 435 |
+
"""
|
| 436 |
+
if not self.reranker or len(documents) == 0:
|
| 437 |
+
return documents
|
| 438 |
+
|
| 439 |
+
try:
|
| 440 |
+
print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
|
| 441 |
+
print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")
|
| 442 |
+
|
| 443 |
+
if self.reranker_type == 'colbert':
|
| 444 |
+
return self._apply_colbert_reranking(query, documents, scores)
|
| 445 |
+
else:
|
| 446 |
+
return self._apply_crossencoder_reranking(query, documents, scores)
|
| 447 |
+
|
| 448 |
+
except Exception as e:
|
| 449 |
+
print(f"❌ RERANKING ERROR: {str(e)}")
|
| 450 |
+
return documents
|
| 451 |
+
|
| 452 |
+
def _apply_crossencoder_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 453 |
+
"""
|
| 454 |
+
Apply reranking using CrossEncoder (BGE and other models).
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
query: User query
|
| 458 |
+
documents: List of documents to rerank
|
| 459 |
+
scores: Original scores
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Reranked list of documents
|
| 463 |
+
"""
|
| 464 |
+
# Prepare pairs for reranking
|
| 465 |
+
pairs = []
|
| 466 |
+
for doc in documents:
|
| 467 |
+
pairs.append([query, doc.page_content])
|
| 468 |
+
|
| 469 |
+
print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")
|
| 470 |
+
|
| 471 |
+
# Get reranking scores using the correct CrossEncoder API
|
| 472 |
+
rerank_scores = self.reranker.predict(pairs)
|
| 473 |
+
|
| 474 |
+
# Handle single score case
|
| 475 |
+
if not isinstance(rerank_scores, (list, np.ndarray)):
|
| 476 |
+
rerank_scores = [rerank_scores]
|
| 477 |
+
|
| 478 |
+
# Ensure we have the right number of scores
|
| 479 |
+
if len(rerank_scores) != len(documents):
|
| 480 |
+
print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
|
| 481 |
+
return documents
|
| 482 |
+
|
| 483 |
+
print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
|
| 484 |
+
print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...") # Show first 5 scores
|
| 485 |
+
|
| 486 |
+
# Combine documents with their rerank scores
|
| 487 |
+
doc_scores = list(zip(documents, rerank_scores))
|
| 488 |
+
|
| 489 |
+
# Sort by rerank score (descending)
|
| 490 |
+
doc_scores.sort(key=lambda x: x[1], reverse=True)
|
| 491 |
+
|
| 492 |
+
# Extract reranked documents and store scores in metadata
|
| 493 |
+
reranked_docs = []
|
| 494 |
+
for i, (doc, rerank_score) in enumerate(doc_scores):
|
| 495 |
+
# Find original index for original score
|
| 496 |
+
original_idx = documents.index(doc)
|
| 497 |
+
original_score = scores[original_idx] if original_idx < len(scores) else 0.0
|
| 498 |
+
|
| 499 |
+
# Create new document with reranking metadata
|
| 500 |
+
new_doc = Document(
|
| 501 |
+
page_content=doc.page_content,
|
| 502 |
+
metadata={
|
| 503 |
+
**doc.metadata,
|
| 504 |
+
'reranking_applied': True,
|
| 505 |
+
'reranker_model': self.reranker_model_name,
|
| 506 |
+
'reranker_type': self.reranker_type,
|
| 507 |
+
'original_rank': original_idx + 1,
|
| 508 |
+
'final_rank': i + 1,
|
| 509 |
+
'original_score': float(original_score),
|
| 510 |
+
'reranked_score': float(rerank_score)
|
| 511 |
+
}
|
| 512 |
+
)
|
| 513 |
+
reranked_docs.append(new_doc)
|
| 514 |
+
|
| 515 |
+
print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")
|
| 516 |
+
|
| 517 |
+
return reranked_docs
|
| 518 |
+
|
| 519 |
+
def _apply_colbert_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 520 |
+
"""
|
| 521 |
+
Apply reranking using ColBERT late interaction.
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
query: User query
|
| 525 |
+
documents: List of documents to rerank
|
| 526 |
+
scores: Original scores
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
Reranked list of documents
|
| 530 |
+
"""
|
| 531 |
+
# Use the actual ColBERT reranking implementation
|
| 532 |
+
return self._colbert_rerank(query, documents, scores)
|
| 533 |
+
|
| 534 |
+
    def _colbert_rerank(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
        """
        ColBERT reranking using late interaction (MaxSim) with support for
        pre-calculated per-document token embeddings.

        Documents that carry a ``colbert_embedding`` entry in their metadata
        skip re-encoding; the rest are encoded in one batch. Each document's
        final score is sum-over-query-tokens of the max similarity against
        the document's tokens.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores (preserved as ``original_score`` metadata)

        Returns:
            New Document objects sorted by ColBERT score (descending), with
            ranking provenance recorded in metadata. On any error the input
            documents are returned unchanged (best-effort fallback).
        """
        try:
            print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")

            # Partition documents: those with cached embeddings vs. those
            # that still need encoding. Indices are recorded so encoded
            # results can be spliced back into original document order.
            pre_calculated_embeddings = []
            documents_without_embeddings = []
            documents_without_indices = []

            for i, doc in enumerate(documents):
                if (hasattr(doc, 'metadata') and
                    'colbert_embedding' in doc.metadata and
                    doc.metadata['colbert_embedding'] is not None):
                    # Use pre-calculated embedding; lists (e.g. deserialized
                    # from the payload store) are converted back to tensors.
                    colbert_embedding = doc.metadata['colbert_embedding']
                    if isinstance(colbert_embedding, list):
                        colbert_embedding = torch.tensor(colbert_embedding)
                    pre_calculated_embeddings.append(colbert_embedding)
                else:
                    # Need to calculate embedding for this document.
                    documents_without_embeddings.append(doc)
                    documents_without_indices.append(i)

            # Encode the query once.
            # NOTE(review): self.colbert_checkpoint presumably is a ColBERT
            # Checkpoint whose queryFromText/docFromText return token-level
            # embedding matrices — confirm against the colbert-ai API.
            query_embeddings = self.colbert_checkpoint.queryFromText([query])

            # Encode only the documents that lacked cached embeddings.
            if documents_without_embeddings:
                print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
                doc_texts = [doc.page_content for doc in documents_without_embeddings]
                doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)

                # Splice freshly computed embeddings back into original
                # positions. Correct because documents_without_indices is
                # ascending, so each insert lands at its final offset.
                for i, embedding in enumerate(doc_embeddings):
                    idx = documents_without_indices[i]
                    pre_calculated_embeddings.insert(idx, embedding)
            else:
                print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")

            # Late interaction scoring.
            # MaxSim: for each query token, find max similarity with document tokens,
            # then sum those maxima into a single scalar per document.
            colbert_scores = []
            for i, doc_embedding in enumerate(pre_calculated_embeddings):
                # (query_tokens x dim) @ (dim x doc_tokens) -> similarity matrix
                sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))

                # Max over document tokens for each query token.
                max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]

                # Sum over query tokens to get the final document score.
                final_score = torch.sum(max_sim_per_query_token).item()
                colbert_scores.append(final_score)

            # Sort documents by ColBERT score, descending.
            doc_scores = list(zip(documents, colbert_scores))
            doc_scores.sort(key=lambda x: x[1], reverse=True)

            # Rebuild Documents with ranking provenance in metadata.
            # NOTE(review): documents.index(doc) returns the first match, so
            # duplicate documents would all report the first occurrence's
            # original rank/score — acceptable only if results are unique.
            reranked_docs = []
            for i, (doc, colbert_score) in enumerate(doc_scores):
                original_idx = documents.index(doc)
                original_score = scores[original_idx] if original_idx < len(scores) else 0.0

                new_doc = Document(
                    page_content=doc.page_content,
                    metadata={
                        **doc.metadata,
                        'reranking_applied': True,
                        'reranker_model': self.reranker_model_name,
                        'reranker_type': self.reranker_type,
                        'original_rank': original_idx + 1,
                        'final_rank': i + 1,
                        'original_score': float(original_score),
                        'reranked_score': float(colbert_score),
                        'colbert_score': float(colbert_score),
                        'colbert_embedding_pre_calculated': 'colbert_embedding' in doc.metadata
                    }
                )
                reranked_docs.append(new_doc)

            print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
            print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")

            return reranked_docs

        except Exception as e:
            print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
            print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
            # Fallback to original order - return documents as-is
            return documents
|
| 635 |
+
|
| 636 |
+
def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5, reports: List[str] = None,
|
| 637 |
+
sources: List[str] = None, subtype: List[str] = None,
|
| 638 |
+
year: List[str] = None, use_reranking: bool = False,
|
| 639 |
+
qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
|
| 640 |
+
"""
|
| 641 |
+
Retrieve context documents with scores using hybrid search with optional reranking.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
query: User query
|
| 645 |
+
vectorstore: Optional vectorstore instance (for compatibility)
|
| 646 |
+
k: Number of documents to retrieve
|
| 647 |
+
reports: List of report names to filter by
|
| 648 |
+
sources: List of sources to filter by
|
| 649 |
+
subtype: Document subtype to filter by
|
| 650 |
+
year: List of years to filter by
|
| 651 |
+
use_reranking: Whether to apply reranking
|
| 652 |
+
qdrant_filter: Pre-built Qdrant filter
|
| 653 |
+
|
| 654 |
+
Returns:
|
| 655 |
+
Tuple of (documents, scores)
|
| 656 |
+
"""
|
| 657 |
+
try:
|
| 658 |
+
# Use the provided vectorstore if available, otherwise use the instance one
|
| 659 |
+
if vectorstore:
|
| 660 |
+
self.vectorstore = vectorstore
|
| 661 |
+
|
| 662 |
+
# Determine search strategy
|
| 663 |
+
search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')
|
| 664 |
+
|
| 665 |
+
if search_strategy == 'vector_only':
|
| 666 |
+
# Vector search only
|
| 667 |
+
print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")
|
| 668 |
+
|
| 669 |
+
if qdrant_filter:
|
| 670 |
+
print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
|
| 671 |
+
# Pass filter as positional argument, not keyword argument
|
| 672 |
+
results = self.vectorstore.similarity_search_with_score(
|
| 673 |
+
query,
|
| 674 |
+
k=k,
|
| 675 |
+
filter=qdrant_filter
|
| 676 |
+
)
|
| 677 |
+
else:
|
| 678 |
+
# Build filter from individual parameters
|
| 679 |
+
filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
|
| 680 |
+
if filter_conditions:
|
| 681 |
+
print(f"✅ FILTER APPLIED: {filter_conditions}")
|
| 682 |
+
results = self.vectorstore.similarity_search_with_score(
|
| 683 |
+
query,
|
| 684 |
+
k=k,
|
| 685 |
+
filter=filter_conditions
|
| 686 |
+
)
|
| 687 |
+
else:
|
| 688 |
+
print(f"ℹ️ NO FILTERS APPLIED: All documents will be searched")
|
| 689 |
+
results = self.vectorstore.similarity_search_with_score(query, k=k)
|
| 690 |
+
|
| 691 |
+
print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
|
| 692 |
+
print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")
|
| 693 |
+
|
| 694 |
+
# Handle different result formats
|
| 695 |
+
if results and isinstance(results[0], tuple):
|
| 696 |
+
documents = [doc for doc, score in results]
|
| 697 |
+
scores = [score for doc, score in results]
|
| 698 |
+
print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
|
| 699 |
+
else:
|
| 700 |
+
documents = results
|
| 701 |
+
scores = [0.0] * len(documents)
|
| 702 |
+
print(f"🔍 SEARCH DEBUG: No scores available, using default")
|
| 703 |
+
|
| 704 |
+
print(f"🔧 CONVERTING: Converting {len(documents)} documents")
|
| 705 |
+
|
| 706 |
+
# Convert to Document objects and store original scores
|
| 707 |
+
final_documents = []
|
| 708 |
+
for i, (doc, score) in enumerate(zip(documents, scores)):
|
| 709 |
+
if hasattr(doc, 'page_content'):
|
| 710 |
+
new_doc = Document(
|
| 711 |
+
page_content=doc.page_content,
|
| 712 |
+
metadata=doc.metadata.copy()
|
| 713 |
+
)
|
| 714 |
+
# Store original score in metadata
|
| 715 |
+
new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
|
| 716 |
+
final_documents.append(new_doc)
|
| 717 |
+
else:
|
| 718 |
+
print(f"⚠️ WARNING: Document {i} has no page_content")
|
| 719 |
+
|
| 720 |
+
print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")
|
| 721 |
+
|
| 722 |
+
# Apply reranking if enabled
|
| 723 |
+
if use_reranking and len(final_documents) > 1:
|
| 724 |
+
print(f"🔄 RERANKING: Applying {self.reranker_model} to {len(final_documents)} documents...")
|
| 725 |
+
final_documents = self._apply_reranking(query, final_documents, scores)
|
| 726 |
+
print(f"✅ RERANKING APPLIED: {self.reranker_model}")
|
| 727 |
+
else:
|
| 728 |
+
print(f"ℹ️ RERANKING: Skipped (disabled or no documents)")
|
| 729 |
+
|
| 730 |
+
return final_documents, scores
|
| 731 |
+
|
| 732 |
+
else:
|
| 733 |
+
print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
|
| 734 |
+
return [], []
|
| 735 |
+
|
| 736 |
+
except Exception as e:
|
| 737 |
+
print(f"❌ RETRIEVAL ERROR: {e}")
|
| 738 |
+
print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
|
| 739 |
+
return [], []
|
| 740 |
+
|
| 741 |
+
def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
|
| 742 |
+
subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
|
| 743 |
+
"""
|
| 744 |
+
Build Qdrant filter conditions from individual parameters.
|
| 745 |
+
|
| 746 |
+
Args:
|
| 747 |
+
reports: List of report names
|
| 748 |
+
sources: List of sources
|
| 749 |
+
subtype: Document subtype
|
| 750 |
+
year: List of years
|
| 751 |
+
|
| 752 |
+
Returns:
|
| 753 |
+
Qdrant filter or None
|
| 754 |
+
"""
|
| 755 |
+
conditions = []
|
| 756 |
+
|
| 757 |
+
if reports:
|
| 758 |
+
conditions.append(rest.FieldCondition(
|
| 759 |
+
key="metadata.filename",
|
| 760 |
+
match=rest.MatchAny(any=reports)
|
| 761 |
+
))
|
| 762 |
+
|
| 763 |
+
if sources:
|
| 764 |
+
conditions.append(rest.FieldCondition(
|
| 765 |
+
key="metadata.source",
|
| 766 |
+
match=rest.MatchAny(any=sources)
|
| 767 |
+
))
|
| 768 |
+
|
| 769 |
+
if subtype:
|
| 770 |
+
conditions.append(rest.FieldCondition(
|
| 771 |
+
key="metadata.subtype",
|
| 772 |
+
match=rest.MatchAny(any=subtype)
|
| 773 |
+
))
|
| 774 |
+
|
| 775 |
+
if year:
|
| 776 |
+
conditions.append(rest.FieldCondition(
|
| 777 |
+
key="metadata.year",
|
| 778 |
+
match=rest.MatchAny(any=year)
|
| 779 |
+
))
|
| 780 |
+
|
| 781 |
+
if conditions:
|
| 782 |
+
return rest.Filter(must=conditions)
|
| 783 |
+
|
| 784 |
+
return None
|
| 785 |
+
|
| 786 |
+
def get_context(
    query: str,
    vectorstore: Qdrant,
    k: int = 5,
    reports: Optional[List[str]] = None,
    sources: Optional[List[str]] = None,
    subtype: Optional[str] = None,
    year: Optional[str] = None,
    use_reranking: bool = False,
    qdrant_filter: Optional[rest.Filter] = None
) -> List[Document]:
    """
    Convenience function to get context documents.

    Constructs a fresh ContextRetriever around the given vectorstore and
    forwards all arguments to its ``retrieve_context`` method.

    Args:
        query: User query
        vectorstore: Qdrant vector store instance
        k: Number of documents to retrieve
        reports: Optional list of report names to filter by
        sources: Optional list of source categories to filter by
        subtype: Optional subtype to filter by
        year: Optional year to filter by
        use_reranking: Whether to apply reranking
        qdrant_filter: Optional pre-built Qdrant filter (takes precedence
            over the individual filter arguments downstream)

    Returns:
        List of retrieved documents
    """
    # NOTE(review): subtype/year are annotated Optional[str] here but the
    # retriever-level filter builders treat them as lists — confirm which
    # shape retrieve_context actually expects.
    retriever = ContextRetriever(vectorstore)
    return retriever.retrieve_context(
        query=query,
        k=k,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year,
        use_reranking=use_reranking,
        qdrant_filter=qdrant_filter
    )
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def format_context_for_llm(documents: List[Document]) -> str:
    """
    Format retrieved documents for LLM input.

    Produces a numbered, source-attributed block per document, separated by
    blank lines.

    Args:
        documents: List of Document objects

    Returns:
        Formatted string for the LLM ("" when no documents were retrieved)
    """
    if not documents:
        return ""

    sections = [
        f"Document {position} (Source: {doc.metadata.get('filename', 'Unknown')}):\n{doc.page_content.strip()}"
        for position, doc in enumerate(documents, 1)
    ]
    return "\n\n".join(sections)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
    """
    Extract metadata summary from retrieved documents.

    Collects the distinct filenames, years, and source categories present in
    the documents' metadata.

    Args:
        documents: List of Document objects

    Returns:
        Dictionary with num_documents, sources (filenames), years, and
        document_types (source categories); {} when the input is empty.
    """
    if not documents:
        return {}

    filenames = set()
    years_seen = set()
    source_categories = set()

    for doc in documents:
        md = doc.metadata
        if 'filename' in md:
            filenames.add(md['filename'])
        if 'year' in md:
            years_seen.add(md['year'])
        if 'source' in md:
            source_categories.add(md['source'])

    return {
        "num_documents": len(documents),
        "sources": list(filenames),
        "years": list(years_seen),
        "document_types": list(source_categories),
    }
|
src/retrieval/filter.py
ADDED
|
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document filtering utilities for Qdrant vector store."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Union, Dict, Tuple, Any
|
| 4 |
+
from qdrant_client.http import models as rest
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FilterBuilder:
    """Fluent builder for Qdrant ``must`` filters over document metadata."""

    def __init__(self):
        # Accumulated FieldCondition objects, AND-ed together in build().
        self.conditions = []

    def add_source_filter(self, source: Union[str, List[str]]) -> 'FilterBuilder':
        """Add source filter condition (list -> MatchAny, scalar -> MatchValue)."""
        if source:
            if isinstance(source, list):
                self.conditions.append(rest.FieldCondition(
                    key="metadata.source",
                    match=rest.MatchAny(any=source)
                ))
                print(f"🔧 FilterBuilder: Added source filter for {source}")
            else:
                self.conditions.append(rest.FieldCondition(
                    key="metadata.source",
                    match=rest.MatchValue(value=source)
                ))
                print(f"🔧 FilterBuilder: Added source filter for '{source}'")
        return self

    def add_filename_filter(self, filenames: List[str]) -> 'FilterBuilder':
        """Add filename filter condition."""
        if filenames:
            self.conditions.append(rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=filenames)
            ))
            print(f"🔧 FilterBuilder: Added filename filter for {filenames}")
        return self

    def add_year_filter(self, years: List[str]) -> 'FilterBuilder':
        """Add year filter condition."""
        if years:
            self.conditions.append(rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=years)
            ))
            print(f"🔧 FilterBuilder: Added year filter for {years}")
        return self

    def add_district_filter(self, districts: List[str]) -> 'FilterBuilder':
        """Add district filter condition."""
        if districts:
            self.conditions.append(rest.FieldCondition(
                key="metadata.district",
                match=rest.MatchAny(any=districts)
            ))
            print(f"🔧 FilterBuilder: Added district filter for {districts}")
        return self

    def add_custom_filter(self, key: str, value: Union[str, List[str]]) -> 'FilterBuilder':
        """Add a condition on an arbitrary payload key (no truthiness guard)."""
        match_clause = (
            rest.MatchAny(any=value)
            if isinstance(value, list)
            else rest.MatchValue(value=value)
        )
        self.conditions.append(rest.FieldCondition(key=key, match=match_clause))
        return self

    def build(self) -> rest.Filter:
        """Build the final filter; None when no conditions were added."""
        if not self.conditions:
            return None
        return rest.Filter(must=self.conditions)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def create_filter(
    reports: List[str] = None,
    sources: Union[str, List[str]] = None,
    subtype: List[str] = None,
    year: List[str] = None,
    district: List[str] = None,
    filenames: List[str] = None
) -> rest.Filter:
    """
    Create a search filter for Qdrant (legacy function for compatibility).

    Fix vs. previous version: the FINAL FILTER log line contained a
    mis-encoded (mojibake) emoji; it is now a valid character.

    Args:
        reports: List of specific report filenames
        sources: Source category
        subtype: List of subtypes/filenames
        year: List of years
        district: List of districts
        filenames: List of specific filenames (mutually exclusive with other filters)

    Returns:
        Qdrant Filter object (None when no filters apply)

    Note:
        If filenames are provided, ONLY filename filtering is applied
        (mutually exclusive). ``reports`` is a legacy alias for
        ``filenames`` and is used as a fallback.
    """
    builder = FilterBuilder()

    # Both filenames and reports serve the same purpose (backward
    # compatibility); prefer filenames, fall back to reports.
    target_filenames = filenames if filenames else reports

    if target_filenames and len(target_filenames) > 0:
        # ONLY apply filename filter, ignore all other filters.
        print(f"🔍 FILTER APPLIED: Filenames = {target_filenames} (mutually exclusive mode)")
        builder.add_filename_filter(target_filenames)
    else:
        # Otherwise, filter by source/subtype/year/district.
        print(f"🔍 FILTER APPLIED: Sources = {sources}, Subtype = {subtype}, Year = {year}, District = {district}")
        if sources:
            print(f"✅ Adding source filter: metadata.source = '{sources}'")
            builder.add_source_filter(sources)
        if subtype:
            # Subtype values are matched against metadata.filename here
            # (historical behavior preserved for compatibility).
            print(f"✅ Adding subtype filter: metadata.filename IN {subtype}")
            builder.add_filename_filter(subtype)
        if year:
            print(f"✅ Adding year filter: metadata.year IN {year}")
            builder.add_year_filter(year)

        if district:
            print(f"✅ Adding district filter: metadata.district IN {district}")
            builder.add_district_filter(district)

    filter_obj = builder.build()

    if filter_obj:
        # Was "�� FINAL FILTER" (mojibake) — fixed encoding.
        print(f"🔍 FINAL FILTER: {len(filter_obj.must)} condition(s) applied")
        for i, condition in enumerate(filter_obj.must, 1):
            print(f"  Condition {i}: {condition.key} = {condition.match}")
    else:
        print("⚠️ NO FILTERS APPLIED: All documents will be searched")

    return filter_obj
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def create_advanced_filter(
    must_conditions: List[dict] = None,
    should_conditions: List[dict] = None,
    must_not_conditions: List[dict] = None
) -> rest.Filter:
    """
    Create advanced filter with multiple condition types.

    Args:
        must_conditions: Conditions that must match (AND)
        should_conditions: Conditions that should match (OR logic)
        must_not_conditions: Conditions that must not match

    Returns:
        Qdrant Filter object, or None when no conditions are given
    """
    sections = {
        "must": must_conditions,
        "should": should_conditions,
        "must_not": must_not_conditions,
    }
    # Only populated sections are forwarded to the Filter constructor.
    filter_kwargs = {
        section: [_dict_to_field_condition(cond) for cond in conds]
        for section, conds in sections.items()
        if conds
    }
    return rest.Filter(**filter_kwargs) if filter_kwargs else None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _dict_to_field_condition(condition_dict: dict) -> rest.FieldCondition:
    """Convert a {"key": ..., "value": ...} dict to a FieldCondition."""
    target_value = condition_dict["value"]
    match_clause = (
        rest.MatchAny(any=target_value)
        if isinstance(target_value, list)
        else rest.MatchValue(value=target_value)
    )
    return rest.FieldCondition(key=condition_dict["key"], match=match_clause)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def validate_filter(filter_obj: rest.Filter) -> bool:
    """
    Validate that a filter object is properly constructed.

    Args:
        filter_obj: Qdrant Filter object (None is considered valid —
            "no filter")

    Returns:
        True if valid; raises ValueError if invalid
    """
    if filter_obj is None:
        return True

    if not isinstance(filter_obj, rest.Filter):
        raise ValueError("Filter must be a rest.Filter object")

    # Valid as soon as any condition section is non-empty.
    for section in ('must', 'should', 'must_not'):
        if getattr(filter_obj, section, None):
            return True

    raise ValueError("Filter must have at least one condition")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _rule_based_filter_fallback(
    query: str,
    available_metadata: dict
) -> Tuple[Optional[rest.Filter], Optional[dict]]:
    """Run the rule-based inference and convert its result to a Qdrant filter.

    Shared fallback used whenever LLM-based inference is unavailable or
    fails; returns (None, None) when no usable filter could be built.
    """
    rule_based_result = _infer_filters_rule_based(query, available_metadata)
    qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
    if qdrant_filter:
        print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
        return qdrant_filter, filter_summary
    print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
    return None, None


def infer_filters_from_query(
    query: str,
    available_metadata: dict,
    llm_client=None
) -> Tuple[rest.Filter, Union[dict, None]]:
    """
    Automatically infer filters from a query using LLM analysis, with a
    rule-based fallback.

    Fixes vs. previous version:
    - The three identical rule-based fallback branches are deduplicated
      into ``_rule_based_filter_fallback``.
    - The no-LLM-client path now also converts the rule-based result into a
      proper Qdrant filter instead of returning the raw inference result.
    - Mojibake emoji in the opening log line replaced with a valid character.

    Args:
        query: User query to analyze
        available_metadata: Available metadata values in the vectorstore
        llm_client: LLM client for analysis (optional)

    Returns:
        Tuple of (Qdrant filter or None, filter summary dict or None)
    """
    print(f"🔍 AUTO-INFERRING FILTERS from query: '{query[:50]}...'")

    if not llm_client:
        print(f"❌ LLM CLIENT MISSING: Cannot use LLM analysis, falling back to rule-based")
        return _rule_based_filter_fallback(query, available_metadata)

    # Log what metadata is available for the LLM to match against.
    available_sources = available_metadata.get('sources', [])
    available_years = available_metadata.get('years', [])
    available_filenames = available_metadata.get('filenames', [])
    print(f"📊 Available metadata: sources={len(available_sources)}, years={len(available_years)}, filenames={len(available_filenames)}")

    # Try LLM analysis first.
    print(f" LLM ANALYSIS: Attempting LLM-based filter inference...")
    llm_result = _analyze_query_with_llm(
        query=query,
        available_metadata=available_metadata,
        llm_client=llm_client
    )

    if llm_result:
        print(f"✅ LLM SUCCESS: LLM successfully inferred filters")
        qdrant_filter, filter_summary = _build_qdrant_filter(llm_result)
        if qdrant_filter:
            print(f"✅ QDRANT FILTER: Successfully built Qdrant filter")
            return qdrant_filter, filter_summary
        print(f"❌ QDRANT FILTER: Failed to build Qdrant filter, trying rule-based fallback")
    else:
        print(f"⚠️ LLM FAILED: LLM could not infer filters, trying rule-based fallback")

    return _rule_based_filter_fallback(query, available_metadata)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def _analyze_query_with_llm(
|
| 305 |
+
query: str,
|
| 306 |
+
available_metadata: Dict[str, List[str]],
|
| 307 |
+
llm_client=None
|
| 308 |
+
) -> dict:
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
"""
|
| 312 |
+
- Filenames: {available_metadata.get('filenames', [])}
|
| 313 |
+
|
| 314 |
+
📁 FILENAME FILTERING (Use Sparingly):
|
| 315 |
+
- Only if specific filename explicitly mentioned
|
| 316 |
+
- Prefer source/subtype over filename
|
| 317 |
+
- Be very conservative
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
"filenames": ["filename1", "filename2"] or [],
|
| 321 |
+
- For filenames: Only use if you have high confidence and can identify specific files
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
"""
|
| 326 |
+
Use LLM to analyze query and infer appropriate filters.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
query: User query to analyze
|
| 330 |
+
available_metadata: Available metadata values in the vectorstore
|
| 331 |
+
llm_client: LLM client for analysis
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Dictionary with inferred filters or empty dict if failed
|
| 335 |
+
"""
|
| 336 |
+
if not llm_client:
|
| 337 |
+
print("❌ LLM CLIENT MISSING: Cannot analyze query without LLM client")
|
| 338 |
+
return {}
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
print(f" LLM ANALYSIS: Analyzing query with LLM...")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
"""
|
| 345 |
+
For example: "What is the expected ... in 2024" - this refference to a future statement, so retrieving documents for 2023, 2022 and 2021 can be relevant too
|
| 346 |
+
Another example: "What is the GDP increase now compared to 2022" - this is a relative statement, refferring to past data, so both Year 2022, and now - 2025 needs to be detected/marked
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
# Create prompt for LLM analysis
|
| 350 |
+
prompt = f"""
|
| 351 |
+
You are a filter inference system. Analyze this query and return ONLY a JSON object.
|
| 352 |
+
|
| 353 |
+
Query: "{query}"
|
| 354 |
+
|
| 355 |
+
Available metadata:
|
| 356 |
+
- Sources: {available_metadata.get('sources', [])}
|
| 357 |
+
- Years: {available_metadata.get('years', [])}
|
| 358 |
+
|
| 359 |
+
FILTER INFERENCE GUIDELINES:
|
| 360 |
+
|
| 361 |
+
YEAR FILTERING (Be VERY Conservative):
|
| 362 |
+
✅ INFER YEARS ONLY IF:
|
| 363 |
+
- Explicit 4-digit years: "2022", "2023", "2021"
|
| 364 |
+
- Clear relative terms: "last year", "this year", "recent", "current year" (for the context, now is 2025)
|
| 365 |
+
- Temporal context: "annual report 2022", "audit for 2023"
|
| 366 |
+
- Give multiple years for complex queries.
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
❌ DO NOT INFER YEARS FOR:
|
| 370 |
+
- Vague terms: "implementation", "activities", "costs", "challenges", "issues"
|
| 371 |
+
- General concepts: "PDM", "administrative", "budget", "staff"
|
| 372 |
+
- Process descriptions: "how were", "what challenges", "management of"
|
| 373 |
+
|
| 374 |
+
🏛️ SOURCE FILTERING (Context-Based):
|
| 375 |
+
- "Ministry, Department and Agency" → Central government, ministries, departments, PS/ST
|
| 376 |
+
- "Local Government" → Districts, municipalities, local authorities, DLG
|
| 377 |
+
- "Consolidated" → Annual consolidated reports, OAG reports
|
| 378 |
+
- "Thematic" → Special studies, thematic reports
|
| 379 |
+
|
| 380 |
+
�� SUBTYPE FILTERING (Document Type):
|
| 381 |
+
- "audit" → Audit reports, reviews, examinations
|
| 382 |
+
- "report" → General reports, annual reports
|
| 383 |
+
- "guidance" → Guidelines, directives, circulars
|
| 384 |
+
|
| 385 |
+
CONFIDENCE SCORING:
|
| 386 |
+
- 0.9-1.0: Crystal clear indicators (explicit years, specific sources)
|
| 387 |
+
- 0.7-0.8: Good indicators (relative years, clear context)
|
| 388 |
+
- 0.5-0.6: Moderate indicators (some context clues)
|
| 389 |
+
- 0.0-0.4: Low confidence (vague or unclear)
|
| 390 |
+
|
| 391 |
+
EXAMPLES:
|
| 392 |
+
✅ "What challenges arose in 2022?" → years: ["2022"], confidence: 1
|
| 393 |
+
✅ "How were administrative costs managed in our government?" → sources: ["Local Government"], confidence: 0.75
|
| 394 |
+
✅ "PDM implementation guidelines from last year" → years: ["2024"], confidence: 0.9
|
| 395 |
+
❌ "What issues arose with budget execution?" → NO FILTERS, confidence: 0.2
|
| 396 |
+
❌ "How were tools related to administrative costs?" → NO FILTERS, confidence: 0.1
|
| 397 |
+
|
| 398 |
+
RESPONSE FORMAT (JSON only):
|
| 399 |
+
{{
|
| 400 |
+
"years": ["2022", "2023"] or [],
|
| 401 |
+
"sources": ["Ministry, Department and Agency", "Local Government"] or [],
|
| 402 |
+
"subtype": ["audit", "report"] or [],
|
| 403 |
+
"confidence": 0.8,
|
| 404 |
+
"reasoning": "Very brief explanation of filter choices"
|
| 405 |
+
}}
|
| 406 |
+
|
| 407 |
+
Rules:
|
| 408 |
+
- Use OR logic (SHOULD) for multiple values
|
| 409 |
+
- Prefer sources over filenames
|
| 410 |
+
- Only include years if clearly mentioned
|
| 411 |
+
- Return null for unclear fields
|
| 412 |
+
- For sources/subtypes: Include at least 3 candidates unless confidence is high and you can identify exactly one source (MUST)
|
| 413 |
+
- For years: If you want to include, then include at least 2 candidates unless confidence is high and you can identify exactly one year (MUST)
|
| 414 |
+
"""
|
| 415 |
+
|
| 416 |
+
print(f"🔄 LLM CALL: Sending prompt to LLM...")
|
| 417 |
+
try:
|
| 418 |
+
# Try different methods to call the LLM
|
| 419 |
+
if hasattr(llm_client, 'invoke'):
|
| 420 |
+
response = llm_client.invoke(prompt)
|
| 421 |
+
elif hasattr(llm_client, 'generate'):
|
| 422 |
+
response = llm_client.generate([{"role": "user", "content": prompt}])
|
| 423 |
+
elif hasattr(llm_client, 'call'):
|
| 424 |
+
response = llm_client.call(prompt)
|
| 425 |
+
elif hasattr(llm_client, 'predict'):
|
| 426 |
+
response = llm_client.predict(prompt)
|
| 427 |
+
else:
|
| 428 |
+
# Try to call it directly
|
| 429 |
+
response = llm_client(prompt)
|
| 430 |
+
|
| 431 |
+
print(f"✅ LLM CALL SUCCESS: Received response from LLM")
|
| 432 |
+
|
| 433 |
+
# Extract content from response
|
| 434 |
+
if hasattr(response, 'content'):
|
| 435 |
+
response_content = response.content
|
| 436 |
+
elif hasattr(response, 'text'):
|
| 437 |
+
response_content = response.text
|
| 438 |
+
elif isinstance(response, str):
|
| 439 |
+
response_content = response
|
| 440 |
+
else:
|
| 441 |
+
response_content = str(response)
|
| 442 |
+
|
| 443 |
+
print(f"🔄 LLM RESPONSE: {response_content[:200]}...")
|
| 444 |
+
|
| 445 |
+
except Exception as e:
|
| 446 |
+
print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
|
| 447 |
+
return {}
|
| 448 |
+
|
| 449 |
+
# Parse JSON response
|
| 450 |
+
import json
|
| 451 |
+
import re
|
| 452 |
+
try:
|
| 453 |
+
print(f"🔄 JSON PARSING: Attempting to parse LLM response...")
|
| 454 |
+
|
| 455 |
+
# Clean the response to extract JSON from markdown
|
| 456 |
+
response_text = response_content.strip()
|
| 457 |
+
|
| 458 |
+
# Remove markdown formatting if present
|
| 459 |
+
if "```json" in response_text:
|
| 460 |
+
# Extract JSON from markdown code block
|
| 461 |
+
start_marker = "```json"
|
| 462 |
+
end_marker = "```"
|
| 463 |
+
start_idx = response_text.find(start_marker)
|
| 464 |
+
if start_idx != -1:
|
| 465 |
+
start_idx += len(start_marker)
|
| 466 |
+
end_idx = response_text.find(end_marker, start_idx)
|
| 467 |
+
if end_idx != -1:
|
| 468 |
+
response_text = response_text[start_idx:end_idx].strip()
|
| 469 |
+
|
| 470 |
+
# Try to find JSON object in the response
|
| 471 |
+
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
| 472 |
+
if json_match:
|
| 473 |
+
response_text = json_match.group(0)
|
| 474 |
+
|
| 475 |
+
print(f"🔄 JSON PARSING: Cleaned response: {response_text[:200]}...")
|
| 476 |
+
|
| 477 |
+
# Parse JSON
|
| 478 |
+
filters = json.loads(response_text)
|
| 479 |
+
print(f"✅ JSON PARSING SUCCESS: Parsed filters: {filters}")
|
| 480 |
+
|
| 481 |
+
# Validate filters
|
| 482 |
+
if not isinstance(filters, dict):
|
| 483 |
+
print(f"❌ JSON VALIDATION FAILED: Response is not a dictionary")
|
| 484 |
+
return {}
|
| 485 |
+
|
| 486 |
+
# Check if any filters were inferred
|
| 487 |
+
has_filters = any(filters.get(key) for key in ['sources', 'years', 'filenames'])
|
| 488 |
+
if not has_filters:
|
| 489 |
+
print(f"⚠️ QUERY DIFFICULT: LLM could not determine appropriate filters from query")
|
| 490 |
+
return {}
|
| 491 |
+
|
| 492 |
+
# print(f"✅ FILTER INFERENCE SUCCESS: Inferred filters: {filters}")
|
| 493 |
+
return filters
|
| 494 |
+
|
| 495 |
+
except json.JSONDecodeError as e:
|
| 496 |
+
print(f"❌ JSON PARSING FAILED: Invalid JSON format - {e}")
|
| 497 |
+
print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
|
| 498 |
+
return {}
|
| 499 |
+
except Exception as e:
|
| 500 |
+
print(f"❌ JSON PARSING FAILED: Unexpected error - {e}")
|
| 501 |
+
print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
|
| 502 |
+
return {}
|
| 503 |
+
|
| 504 |
+
except Exception as e:
|
| 505 |
+
print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
|
| 506 |
+
return {}
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _infer_filters_rule_based(
|
| 510 |
+
query: str,
|
| 511 |
+
available_metadata: dict
|
| 512 |
+
) -> dict:
|
| 513 |
+
"""
|
| 514 |
+
Rule-based fallback for filter inference with improved logic.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
query: User query
|
| 518 |
+
available_metadata: Available metadata values in the vectorstore
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
Dictionary of inferred filters
|
| 522 |
+
"""
|
| 523 |
+
print(f" RULE-BASED ANALYSIS: Starting rule-based inference for query: '{query[:50]}...'")
|
| 524 |
+
|
| 525 |
+
inferred = {}
|
| 526 |
+
query_lower = query.lower()
|
| 527 |
+
|
| 528 |
+
# SEMANTIC SOURCE INFERENCE - Use semantic understanding
|
| 529 |
+
source_matches = []
|
| 530 |
+
|
| 531 |
+
# Define semantic mappings for better source inference
|
| 532 |
+
source_keywords = {
|
| 533 |
+
'consolidated': ['consolidated', 'annual', 'oag', 'auditor general', 'government', 'financial statements', 'budget', 'expenditure', 'revenue'],
|
| 534 |
+
'military': ['military', 'defence', 'defense', 'army', 'navy', 'air force', 'security', 'defense ministry'],
|
| 535 |
+
'departmental': ['department', 'ministry', 'agency', 'authority', 'commission', 'board', 'directorate'],
|
| 536 |
+
'thematic': ['thematic', 'sector', 'program', 'project', 'initiative', 'development', 'infrastructure']
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
for source in available_metadata.get('sources', []):
|
| 540 |
+
source_lower = source.lower()
|
| 541 |
+
|
| 542 |
+
# Direct keyword match
|
| 543 |
+
if source_lower in query_lower:
|
| 544 |
+
source_matches.append(source)
|
| 545 |
+
print(f"✅ DIRECT MATCH: Found direct keyword match for '{source}'")
|
| 546 |
+
else:
|
| 547 |
+
# Semantic keyword matching
|
| 548 |
+
if source_lower in source_keywords:
|
| 549 |
+
keywords = source_keywords[source_lower]
|
| 550 |
+
matches = sum(1 for keyword in keywords if keyword in query_lower)
|
| 551 |
+
if matches >= 2: # Require at least 2 keyword matches for semantic inference
|
| 552 |
+
source_matches.append(source)
|
| 553 |
+
print(f"✅ SEMANTIC MATCH: Found {matches} semantic keywords for '{source}': {[k for k in keywords if k in query_lower]}")
|
| 554 |
+
|
| 555 |
+
if source_matches:
|
| 556 |
+
# Use SHOULD (OR logic) for multiple sources
|
| 557 |
+
inferred['sources_should'] = source_matches
|
| 558 |
+
print(f"✅ SOURCE INFERENCE: Found {len(source_matches)} sources with OR logic: {source_matches}")
|
| 559 |
+
else:
|
| 560 |
+
print("❌ SOURCE INFERENCE: No source keywords found in query")
|
| 561 |
+
|
| 562 |
+
# Infer year filters - use SHOULD (OR logic) for multiple years
|
| 563 |
+
import re
|
| 564 |
+
year_matches = []
|
| 565 |
+
for year in available_metadata.get('years', []):
|
| 566 |
+
if year in query or f"'{year}" in query:
|
| 567 |
+
year_matches.append(year)
|
| 568 |
+
|
| 569 |
+
if year_matches:
|
| 570 |
+
# Use SHOULD (OR logic) for multiple years
|
| 571 |
+
inferred['years_should'] = year_matches
|
| 572 |
+
print(f"✅ YEAR INFERENCE: Found {len(year_matches)} years with OR logic: {year_matches}")
|
| 573 |
+
else:
|
| 574 |
+
print("❌ YEAR INFERENCE: No year references found in query")
|
| 575 |
+
|
| 576 |
+
# Only infer filename filters if no year filter was found (to avoid conflicts)
|
| 577 |
+
if not year_matches:
|
| 578 |
+
filename_matches = []
|
| 579 |
+
for filename in available_metadata.get('filenames', []):
|
| 580 |
+
# Only match if multiple words from filename appear in query
|
| 581 |
+
filename_words = filename.lower().split()
|
| 582 |
+
matches = sum(1 for word in filename_words if word in query_lower)
|
| 583 |
+
if matches >= 2: # High confidence threshold
|
| 584 |
+
filename_matches.append(filename)
|
| 585 |
+
|
| 586 |
+
if filename_matches:
|
| 587 |
+
# Use SHOULD (OR logic) for multiple filenames
|
| 588 |
+
inferred['filenames_should'] = filename_matches
|
| 589 |
+
print(f"✅ FILENAME INFERENCE: Found {len(filename_matches)} filenames with OR logic: {filename_matches}")
|
| 590 |
+
else:
|
| 591 |
+
print("❌ FILENAME INFERENCE: No high-confidence filename matches found")
|
| 592 |
+
else:
|
| 593 |
+
print("ℹ️ FILENAME INFERENCE: Skipped (year filter already applied to avoid conflicts)")
|
| 594 |
+
|
| 595 |
+
print(f" RULE-BASED RESULT: {inferred}")
|
| 596 |
+
return inferred
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def _validate_inferred_filters(inferred_filters: dict) -> dict:
|
| 600 |
+
"""
|
| 601 |
+
Validate and normalize inferred filters to ensure they're in the expected format.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
inferred_filters: Raw inferred filters dictionary
|
| 605 |
+
|
| 606 |
+
Returns:
|
| 607 |
+
Validated and normalized filters dictionary
|
| 608 |
+
"""
|
| 609 |
+
if not isinstance(inferred_filters, dict):
|
| 610 |
+
print(f"⚠️ FILTER VALIDATION: Inferred filters is not a dict: {type(inferred_filters)}")
|
| 611 |
+
return {}
|
| 612 |
+
|
| 613 |
+
validated = {}
|
| 614 |
+
|
| 615 |
+
# Normalize field names and validate values
|
| 616 |
+
for field_name in ['sources', 'sources_should', 'years', 'years_should', 'filenames', 'filenames_should']:
|
| 617 |
+
if field_name in inferred_filters and inferred_filters[field_name]:
|
| 618 |
+
value = inferred_filters[field_name]
|
| 619 |
+
if isinstance(value, list) and len(value) > 0:
|
| 620 |
+
# Remove any None or empty string values
|
| 621 |
+
clean_value = [v for v in value if v is not None and str(v).strip()]
|
| 622 |
+
if clean_value:
|
| 623 |
+
validated[field_name] = clean_value
|
| 624 |
+
print(f"✅ FILTER VALIDATION: {field_name} = {clean_value}")
|
| 625 |
+
elif isinstance(value, str) and value.strip():
|
| 626 |
+
validated[field_name] = [value.strip()]
|
| 627 |
+
print(f"✅ FILTER VALIDATION: {field_name} = [{value.strip()}]")
|
| 628 |
+
|
| 629 |
+
return validated
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def _build_qdrant_filter(inferred_filters: dict) -> tuple:
    """
    Build a Qdrant filter from inferred filters.

    Args:
        inferred_filters: Dictionary with inferred filter values
            (supports both 'X' and 'X_should' field names).

    Returns:
        Tuple of (Qdrant Filter object or None, filter summary dict).
    """
    # Initialize up front so the except handler below can never hit an
    # unbound local (the original raised NameError if validation failed early).
    validated_filters = {}
    try:
        from qdrant_client.http import models as rest

        # Validate and normalize the inferred filters first.
        validated_filters = _validate_inferred_filters(inferred_filters)
        if not validated_filters:
            print(f"⚠️ NO VALID FILTERS: All filters were invalid or empty")
            return None, {}

        conditions = []
        filter_summary = {}

        # (primary key, alternate *_should key, Qdrant payload key, summary key)
        field_specs = [
            ('sources', 'sources_should', 'metadata.source', 'sources'),
            ('years', 'years_should', 'metadata.year', 'years'),
            ('filenames', 'filenames_should', 'metadata.filename', 'filenames'),
        ]

        for primary_key, should_key, payload_key, summary_key in field_specs:
            values = validated_filters.get(primary_key) or validated_filters.get(should_key)
            if not (values and isinstance(values, list) and len(values) > 0):
                continue

            if len(values) == 1:
                conditions.append(rest.FieldCondition(
                    key=payload_key,
                    match=rest.MatchValue(value=values[0])
                ))
            else:
                # MatchAny gives OR semantics without nesting Filter(should=...),
                # which previously triggered a QueryPoints error.
                conditions.append(rest.FieldCondition(
                    key=payload_key,
                    match=rest.MatchAny(any=values)
                ))
            filter_summary[summary_key] = f"SHOULD: {values}"

        if conditions:
            # Always wrap conditions in a Filter object, even for single conditions.
            result_filter = rest.Filter(must=conditions)
            print(f"✅ APPLIED FILTERS: {filter_summary}")
            return result_filter, filter_summary

        print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
        return None, {}

    except Exception as e:
        print(f"❌ FILTER BUILD ERROR: {str(e)}")
        print(f"🔍 DEBUG: Original inferred filters keys: {list(inferred_filters.keys()) if isinstance(inferred_filters, dict) else 'Not a dict'}")
        print(f"🔍 DEBUG: Original inferred filters content: {inferred_filters}")
        print(f"🔍 DEBUG: Validated filters keys: {list(validated_filters.keys()) if isinstance(validated_filters, dict) else 'Not a dict'}")
        print(f"🔍 DEBUG: Validated filters content: {validated_filters}")
        # Return a safe fallback - no filter (search all documents)
        return None, {}
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
class MetadataCache:
    """Cache for vectorstore metadata to avoid repeated queries.

    Scans the whole Qdrant collection once (scroll API, with a
    similarity-search fallback) to collect the distinct 'source', 'year'
    and 'filename' payload values, then serves them from an in-process
    cache with a fixed TTL.
    """

    def __init__(self):
        # Cached metadata dict ({'sources': [...], 'years': [...], 'filenames': [...]}),
        # or None until the first successful load.
        self._cache = None
        # time.time() of the last successful load; None until first load.
        self._last_updated = None
        self._cache_ttl = 3600  # 1 hour TTL

    def get_metadata(self, vectorstore) -> dict:
        """
        Get metadata from cache or load it if not available/expired.

        Args:
            vectorstore: QdrantVectorStore instance

        Returns:
            Dictionary of available metadata values with keys
            'sources', 'years', 'filenames' (sorted lists; empty on error).
        """
        import time

        # Check if cache is valid (loaded at least once and within TTL).
        if (self._cache is not None and
            self._last_updated is not None and
            time.time() - self._last_updated < self._cache_ttl):
            print(f"✅ METADATA CACHE: Using cached metadata")
            return self._cache

        try:
            print(f"🔄 METADATA CACHE: Loading metadata from vectorstore...")

            # Get collection info.
            # NOTE(review): reaches into the private vectorstore._client attribute
            # of QdrantVectorStore — confirm this stays available across
            # langchain-qdrant versions.
            try:
                collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
                print(f"✅ Collection info retrieved: {getattr(collection_info, 'name', 'unknown')}")
            except Exception as e:
                print(f"⚠️ Could not get collection info: {e}")

            # Get ALL documents to extract complete metadata
            print(f"📄 Scanning entire corpus for complete metadata extraction...")

            # Get collection info to determine total size (fetched again; only
            # used for the progress percentage below).
            try:
                collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
                total_points = getattr(collection_info, 'points_count', 0)
                print(f"📊 Total documents in corpus: {total_points}")
            except Exception as e:
                print(f"⚠️ Could not get collection size: {e}")
                total_points = 0

            # Accumulators for unique metadata values across ALL documents.
            sources = set()
            years = set()
            filenames = set()

            # Try to use scroll to get all documents in batches.
            batch_size = 1000  # Process in batches to avoid memory issues
            offset = None  # scroll pagination cursor; None means "start"
            processed_count = 0
            # NOTE(review): scroll_success is set but never read afterwards —
            # possibly vestigial.
            scroll_success = False

            try:
                while True:
                    # Scroll through all documents (payload only, no vectors).
                    scroll_result = vectorstore._client.scroll(
                        collection_name=vectorstore.collection_name,
                        limit=batch_size,
                        offset=offset,
                        with_payload=True,
                        with_vectors=False  # We only need metadata
                    )

                    points = scroll_result[0]  # Get the points
                    if not points:
                        break  # No more documents

                    # Process each document in this batch.
                    for i, point in enumerate(points):
                        if hasattr(point, 'payload') and point.payload:
                            payload = point.payload

                            # Debug: Log structure of first few documents
                            if processed_count + i < 2:  # Only log first 2 documents
                                print(f"🔍 DEBUG Document {processed_count + i + 1} payload structure:")
                                print(f" Payload keys: {list(payload.keys()) if isinstance(payload, dict) else 'Not a dict'}")
                                if isinstance(payload, dict) and 'metadata' in payload:
                                    print(f" Metadata keys: {list(payload['metadata'].keys()) if isinstance(payload['metadata'], dict) else 'Not a dict'}")
                                elif isinstance(payload, dict):
                                    print(f" Top-level keys: {list(payload.keys())}")
                                print(f" Payload type: {type(payload)}")
                                print(f" Payload sample: {str(payload)[:200]}...")
                                print()

                            # Try different metadata structures; payloads may nest
                            # metadata under 'metadata', store it top-level, or nest
                            # it under some other key.
                            found_metadata = False

                            # Structure 1: payload['metadata']['source']
                            if isinstance(payload, dict) and 'metadata' in payload:
                                metadata = payload['metadata']
                                if isinstance(metadata, dict):
                                    if 'source' in metadata:
                                        sources.add(metadata['source'])
                                        found_metadata = True
                                    if 'year' in metadata:
                                        years.add(metadata['year'])
                                        found_metadata = True
                                    if 'filename' in metadata:
                                        filenames.add(metadata['filename'])
                                        found_metadata = True

                            # Structure 2: payload['source'] (direct)
                            if isinstance(payload, dict):
                                if 'source' in payload:
                                    sources.add(payload['source'])
                                    found_metadata = True
                                if 'year' in payload:
                                    years.add(payload['year'])
                                    found_metadata = True
                                if 'filename' in payload:
                                    filenames.add(payload['filename'])
                                    found_metadata = True

                            # Structure 3: Check for nested structures
                            if not found_metadata and isinstance(payload, dict):
                                # Look for any nested dict that might contain metadata
                                for key, value in payload.items():
                                    if isinstance(value, dict):
                                        if 'source' in value:
                                            sources.add(value['source'])
                                            found_metadata = True
                                        if 'year' in value:
                                            years.add(value['year'])
                                            found_metadata = True
                                        if 'filename' in value:
                                            filenames.add(value['filename'])
                                            found_metadata = True

                    processed_count += len(points)
                    progress_pct = (processed_count / total_points * 100) if total_points > 0 else 0
                    print(f"📄 Processed {processed_count}/{total_points} documents ({progress_pct:.1f}%)... (sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")

                    # Update offset for next batch
                    offset = scroll_result[1]  # Next offset
                    if offset is None:
                        break  # No more documents

                scroll_success = True
                print(f"✅ Scroll method successful - processed {processed_count} documents")

            except Exception as e:
                print(f"❌ Scroll method failed: {e}")
                print(f"🔄 Falling back to similarity search method...")

                # Fallback: Use similarity search with multiple queries to get more coverage
                fallback_queries = [
                    "",  # Empty query
                    "audit", "report", "government", "ministry", "department",
                    "local", "consolidated", "annual", "financial", "budget",
                    "2020", "2021", "2022", "2023", "2024"  # Year queries
                ]

                processed_count = 0
                for query in fallback_queries:
                    try:
                        # Get documents for this query
                        docs = vectorstore.similarity_search(query, k=1000)  # Get more per query

                        for j, doc in enumerate(docs):
                            if hasattr(doc, 'metadata') and doc.metadata:
                                # Debug: Log structure of first few documents in fallback
                                if processed_count + j < 3:  # Only log first 3 documents per query
                                    print(f"🔍 DEBUG Fallback Document {processed_count + j + 1} (query: '{query}') metadata structure:")
                                    print(f" Metadata keys: {list(doc.metadata.keys()) if isinstance(doc.metadata, dict) else 'Not a dict'}")
                                    print(f" Metadata type: {type(doc.metadata)}")
                                    print(f" Metadata sample: {str(doc.metadata)[:200]}...")
                                    print()

                                if 'source' in doc.metadata:
                                    sources.add(doc.metadata['source'])
                                if 'year' in doc.metadata:
                                    years.add(doc.metadata['year'])
                                if 'filename' in doc.metadata:
                                    filenames.add(doc.metadata['filename'])

                        processed_count += len(docs)
                        print(f"📄 Fallback query '{query}': {len(docs)} docs (total: {processed_count}, sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")

                    except Exception as query_error:
                        print(f"⚠️ Fallback query '{query}' failed: {query_error}")
                        continue

                print(f"✅ Fallback method completed - processed {processed_count} documents")

            print(f"✅ Completed scanning {processed_count} documents from entire corpus")

            # Convert to sorted lists.
            # NOTE(review): assumes all values of each field share one type;
            # sorted() raises TypeError on mixed str/int years — confirm the
            # payload schema.
            metadata = {
                'sources': sorted(list(sources)),
                'years': sorted(list(years)),
                'filenames': sorted(list(filenames))
            }

            # Cache the results
            self._cache = metadata
            self._last_updated = time.time()

            print(f"✅ Complete metadata extracted from entire corpus: {len(sources)} sources, {len(years)} years, {len(filenames)} files")

            # Debug: Show what was actually found
            if sources:
                print(f"📁 Sources found: {sorted(list(sources))}")
            else:
                print(f"❌ No sources found - check metadata structure")

            if years:
                print(f"📅 Years found: {sorted(list(years))}")
            else:
                print(f"❌ No years found - check metadata structure")

            if filenames:
                print(f"📄 Filenames found: {sorted(list(filenames))[:10]}{'...' if len(filenames) > 10 else ''}")
            else:
                print(f"❌ No filenames found - check metadata structure")
            return metadata

        except Exception as e:
            # Any unexpected failure yields an empty (but well-shaped) result;
            # note the failure is NOT cached, so the next call retries.
            print(f"❌ Error extracting metadata: {e}")
            return {'sources': [], 'years': [], 'filenames': []}
|
| 969 |
+
|
| 970 |
+
# Global metadata cache: module-level singleton shared by get_available_metadata();
# expiry is governed by MetadataCache's internal TTL.
_metadata_cache = MetadataCache()
|
| 972 |
+
|
| 973 |
+
def get_available_metadata(vectorstore) -> dict:
    """Return the cached sources/years/filenames metadata for the given vectorstore."""
    cache = _metadata_cache
    return cache.get_metadata(vectorstore)
|
src/retrieval/hybrid.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hybrid search implementation combining vector and sparse retrieval."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Dict, Any, Tuple
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from langchain.docstore.document import Document
|
| 8 |
+
from langchain_qdrant import QdrantVectorStore
|
| 9 |
+
from langchain_community.retrievers import BM25Retriever
|
| 10 |
+
from .filter import create_filter
|
| 11 |
+
import pickle
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HybridRetriever:
    """
    Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.
    Supports configurable search modes: vector_only, sparse_only, or hybrid.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize hybrid retriever.

        Args:
            config: Configuration dictionary with hybrid search settings.
                Reads ``bm25.top_k`` and ``retriever.top_k`` when present.
        """
        self.config = config
        self.bm25_retriever = None    # built lazily by initialize_bm25()
        self.documents = []           # documents currently indexed by BM25
        self._bm25_cache_file = None  # path of the pickled BM25 index

    def _get_bm25_cache_path(self) -> str:
        """Get path for BM25 cache file (creates cache/bm25/ if missing)."""
        cache_dir = Path("cache/bm25")
        cache_dir.mkdir(parents=True, exist_ok=True)
        return str(cache_dir / "bm25_retriever.pkl")

    def initialize_bm25(self, documents: List[Document], force_rebuild: bool = False) -> None:
        """
        Initialize BM25 retriever with documents.

        Loads a pickled index from cache when available, otherwise builds a
        new one and caches it.  On failure, BM25 search is disabled (the
        retriever stays ``None``) rather than raising.

        NOTE(review): the cache is keyed only by a fixed file path, so a stale
        index built from a *different* document set is loaded silently unless
        ``force_rebuild=True`` is passed — confirm this is intended.

        Args:
            documents: List of Document objects to index
            force_rebuild: Whether to force rebuilding the BM25 index
        """
        self.documents = documents
        self._bm25_cache_file = self._get_bm25_cache_path()

        # Try to load cached BM25 retriever
        if not force_rebuild and os.path.exists(self._bm25_cache_file):
            try:
                print("Loading cached BM25 retriever...")
                with open(self._bm25_cache_file, 'rb') as f:
                    self.bm25_retriever = pickle.load(f)
                print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
                return
            except Exception as e:
                print(f"⚠️ Failed to load cached BM25 retriever: {e}")
                print("Building new BM25 index...")

        # Build new BM25 retriever
        print("Building BM25 index...")
        try:
            # Use langchain's BM25Retriever
            self.bm25_retriever = BM25Retriever.from_documents(documents)

            # Configure BM25 parameters
            bm25_config = self.config.get("bm25", {})
            k = bm25_config.get("top_k", 20)
            self.bm25_retriever.k = k

            # Cache the BM25 retriever
            with open(self._bm25_cache_file, 'wb') as f:
                pickle.dump(self.bm25_retriever, f)
            print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")

        except Exception as e:
            print(f"❌ Failed to build BM25 retriever: {e}")
            print("BM25 search will be disabled")
            self.bm25_retriever = None

    def _filter_documents_by_metadata(
        self,
        documents: List[Document],
        reports: List[str] = None,
        sources: str = None,
        subtype: List[str] = None,
        year: List[str] = None
    ) -> List[Document]:
        """
        Filter documents by metadata criteria.

        A document is kept only if it satisfies every criterion that was
        supplied; ``None``/empty criteria are ignored.

        Args:
            documents: List of documents to filter
            reports: List of specific report filenames (substring match)
            sources: Source category (exact match)
            subtype: List of subtypes (membership)
            year: List of years (membership, compared as strings)

        Returns:
            Filtered list of documents
        """
        if not any([reports, sources, subtype, year]):
            return documents

        filtered_docs = []
        for doc in documents:
            metadata = doc.metadata

            # Filter by reports (substring match against filename)
            if reports:
                filename = metadata.get('filename', '')
                if not any(report in filename for report in reports):
                    continue

            # Filter by sources (exact match)
            if sources:
                doc_source = metadata.get('source', '')
                if sources != doc_source:
                    continue

            # Filter by subtype
            if subtype:
                doc_subtype = metadata.get('subtype', '')
                if doc_subtype not in subtype:
                    continue

            # Filter by year (metadata value may be int or str)
            if year:
                doc_year = str(metadata.get('year', ''))
                if doc_year not in year:
                    continue

            filtered_docs.append(doc)

        return filtered_docs

    def _bm25_search(
        self,
        query: str,
        k: int = 20,
        reports: List[str] = None,
        sources: str = None,
        subtype: List[str] = None,
        year: List[str] = None
    ) -> List[Tuple[Document, float]]:
        """
        Perform BM25 sparse search.

        Args:
            query: Search query
            k: Number of documents to retrieve
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years

        Returns:
            List of (Document, score) tuples; empty on failure or when
            BM25 is unavailable.
        """
        if not self.bm25_retriever:
            print("⚠️ BM25 retriever not available")
            return []

        try:
            # Get BM25 results (retriever reads its k attribute at call time)
            self.bm25_retriever.k = k
            bm25_docs = self.bm25_retriever.invoke(query)

            # Apply metadata filtering (BM25 itself cannot filter)
            if any([reports, sources, subtype, year]):
                bm25_docs = self._filter_documents_by_metadata(
                    bm25_docs, reports, sources, subtype, year
                )

            # BM25Retriever doesn't return scores directly, so we'll use placeholder scores
            # In a production system, you'd want to access the actual BM25 scores
            results = []
            for i, doc in enumerate(bm25_docs):
                # Assign decreasing scores based on rank (higher rank = higher score)
                # Normalize to [0, 1] range for consistency with vector search
                score = max(0.1, 1.0 - (i / max(len(bm25_docs), 1)))
                results.append((doc, score))

            return results

        except Exception as e:
            print(f"❌ BM25 search failed: {e}")
            return []

    def _vector_search(
        self,
        vectorstore: QdrantVectorStore,
        query: str,
        k: int = 20,
        reports: List[str] = None,
        sources: str = None,
        subtype: List[str] = None,
        year: List[str] = None
    ) -> List[Tuple[Document, float]]:
        """
        Perform vector similarity search.

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            k: Number of documents to retrieve
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years

        Returns:
            List of (Document, score) tuples; empty on failure.
        """
        try:
            # Build a Qdrant filter from the metadata criteria (None if no criteria)
            filter_obj = create_filter(
                reports=reports,
                sources=sources,
                subtype=subtype,
                year=year
            )

            # Perform vector search
            if filter_obj:
                results = vectorstore.similarity_search_with_score(
                    query, k=k, filter=filter_obj
                )
            else:
                results = vectorstore.similarity_search_with_score(query, k=k)

            return results

        except Exception as e:
            print(f"❌ Vector search failed: {e}")
            return []

    def _normalize_scores(self, results: List[Tuple[Document, float]], method: str = "min_max") -> List[Tuple[Document, float]]:
        """
        Normalize scores to [0, 1] range.

        Args:
            results: List of (Document, score) tuples
            method: Normalization method ('min_max' or 'z_score');
                any other value returns the results unchanged.

        Returns:
            List of (Document, normalized_score) tuples
        """
        if not results:
            return results

        scores = [score for _, score in results]

        if method == "min_max":
            min_score = min(scores)
            max_score = max(scores)
            if max_score == min_score:
                # All scores identical — treat them all as equally relevant
                normalized_results = [(doc, 1.0) for doc, _ in results]
            else:
                normalized_results = [
                    (doc, (score - min_score) / (max_score - min_score))
                    for doc, score in results
                ]
        elif method == "z_score":
            mean_score = np.mean(scores)
            std_score = np.std(scores)
            if std_score == 0:
                normalized_results = [(doc, 1.0) for doc, _ in results]
            else:
                # Clamp below-mean scores to 0 so all outputs stay non-negative
                normalized_results = [
                    (doc, max(0, (score - mean_score) / std_score))
                    for doc, score in results
                ]
        else:
            normalized_results = results

        return normalized_results

    @staticmethod
    def _doc_key(doc: Document) -> Tuple:
        """
        Stable identity key for a document across retrievers.

        Vector search and BM25 each build their own Document objects for the
        same underlying chunk, so keying on ``id(doc)`` can never match
        between the two result sets and score fusion would silently fail to
        merge them.  Keying on content plus (stringified) metadata fixes
        that, and stays hashable even for list/dict metadata values.
        """
        metadata = doc.metadata or {}
        return (
            doc.page_content,
            tuple(sorted((key, str(value)) for key, value in metadata.items())),
        )

    def _combine_results(
        self,
        vector_results: List[Tuple[Document, float]],
        bm25_results: List[Tuple[Document, float]],
        alpha: float = 0.5
    ) -> List[Tuple[Document, float]]:
        """
        Combine vector and BM25 results with weighted scoring.

        Args:
            vector_results: Vector search results
            bm25_results: BM25 search results
            alpha: Weight for vector scores (1-alpha for BM25 scores)

        Returns:
            Combined and ranked (Document, score) tuples, best first
        """
        # Normalize both score distributions so they are comparable
        vector_results = self._normalize_scores(vector_results)
        bm25_results = self._normalize_scores(bm25_results)

        # Map content-based identity -> (doc, score) for each result set.
        # (Previously keyed on id(doc), which never matched across the two
        # retrievers — documents found by both methods were double-counted
        # instead of having their scores fused.)
        vector_docs = {self._doc_key(doc): (doc, score) for doc, score in vector_results}
        bm25_docs = {self._doc_key(doc): (doc, score) for doc, score in bm25_results}

        # Weighted score fusion over the union of both result sets
        combined_scores = {}
        for doc_key in set(vector_docs) | set(bm25_docs):
            vector_score = vector_docs.get(doc_key, (None, 0.0))[1]
            bm25_score = bm25_docs.get(doc_key, (None, 0.0))[1]

            # Weighted combination
            combined_score = alpha * vector_score + (1 - alpha) * bm25_score

            # Get document object from whichever result set has it
            doc = vector_docs.get(doc_key, bm25_docs.get(doc_key))[0]
            combined_scores[doc_key] = (doc, combined_score)

        # Sort by combined score (descending)
        return sorted(
            combined_scores.values(),
            key=lambda x: x[1],
            reverse=True
        )

    def retrieve(
        self,
        vectorstore: QdrantVectorStore,
        query: str,
        mode: str = "hybrid",
        reports: List[str] = None,
        sources: str = None,
        subtype: List[str] = None,
        year: List[str] = None,
        alpha: float = 0.5,
        k: int = None
    ) -> List[Document]:
        """
        Retrieve documents using the specified search mode.

        Thin wrapper around :meth:`retrieve_with_scores` that drops the
        scores (the two methods previously duplicated the same logic).

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years
            alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
            k: Number of documents to retrieve (defaults to retriever.top_k)

        Returns:
            List of relevant Document objects

        Raises:
            ValueError: If ``mode`` is not a recognized search mode.
        """
        return [
            doc for doc, _score in self.retrieve_with_scores(
                vectorstore, query, mode=mode, reports=reports, sources=sources,
                subtype=subtype, year=year, alpha=alpha, k=k
            )
        ]

    def retrieve_with_scores(
        self,
        vectorstore: QdrantVectorStore,
        query: str,
        mode: str = "hybrid",
        reports: List[str] = None,
        sources: str = None,
        subtype: List[str] = None,
        year: List[str] = None,
        alpha: float = 0.5,
        k: int = None
    ) -> List[Tuple[Document, float]]:
        """
        Retrieve documents with scores using the specified search mode.

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years
            alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
            k: Number of documents to retrieve (defaults to retriever.top_k)

        Returns:
            List of (Document, score) tuples

        Raises:
            ValueError: If ``mode`` is not a recognized search mode.
        """
        if k is None:
            k = self.config.get("retriever", {}).get("top_k", 20)

        if mode == "vector_only":
            # Vector search only
            results = self._vector_search(
                vectorstore, query, k, reports, sources, subtype, year
            )

        elif mode == "sparse_only":
            # BM25 search only
            results = self._bm25_search(
                query, k, reports, sources, subtype, year
            )

        elif mode == "hybrid":
            # Hybrid search - combine both.
            # Get more results from each method to have better fusion
            retrieval_k = min(k * 2, 50)  # Get more candidates for fusion

            vector_results = self._vector_search(
                vectorstore, query, retrieval_k, reports, sources, subtype, year
            )
            bm25_results = self._bm25_search(
                query, retrieval_k, reports, sources, subtype, year
            )

            results = self._combine_results(vector_results, bm25_results, alpha)

        else:
            raise ValueError(f"Unknown search mode: {mode}")

        # Limit to top k results
        return results[:k]
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def get_available_search_modes() -> List[str]:
    """Get list of available search modes."""
    modes = ("vector_only", "sparse_only", "hybrid")
    return list(modes)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def get_search_mode_description() -> Dict[str, str]:
    """Get descriptions for each search mode."""
    descriptions: Dict[str, str] = {}
    descriptions["vector_only"] = "Semantic search using dense embeddings - good for conceptual matching"
    descriptions["sparse_only"] = "Keyword search using BM25 - good for exact term matching"
    descriptions["hybrid"] = "Combined semantic and keyword search - balanced approach"
    return descriptions
|
src/vectorstore.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector store management and operations."""
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, Any, List, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from langchain_qdrant import QdrantVectorStore
|
| 8 |
+
from langchain.docstore.document import Document
|
| 9 |
+
from langchain_core.embeddings import Embeddings
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MatryoshkaEmbeddings(Embeddings):
    """Custom embeddings class that supports Matryoshka dimension truncation."""

    def __init__(self, model_name: str, truncate_dim: int = None, **kwargs):
        """
        Initialize Matryoshka embeddings.

        Args:
            model_name: Name of the model
            truncate_dim: Dimension to truncate to (for Matryoshka models)
            **kwargs: Additional arguments (ignored for Matryoshka models)
        """
        self.model_name = model_name
        self.truncate_dim = truncate_dim

        if self._uses_truncation():
            # SentenceTransformer understands truncate_dim natively
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
            print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
        else:
            # Fall back to the standard LangChain HuggingFace wrapper
            self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    def _uses_truncation(self) -> bool:
        """True when the Matryoshka truncation path is active."""
        return bool(self.truncate_dim) and "matryoshka" in self.model_name.lower()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents."""
        if self._uses_truncation():
            return self.model.encode(texts, normalize_embeddings=True).tolist()
        return self.model.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query."""
        if self._uses_truncation():
            return self.model.encode([text], normalize_embeddings=True)[0].tolist()
        return self.model.embed_query(text)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class VectorStoreManager:
    """Manages vector store operations and connections."""

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize vector store manager.

        Note: constructing the manager eagerly loads the embeddings model.

        Args:
            config: Configuration dictionary. Reads ``retriever.model``,
                ``retriever.normalize`` and the ``qdrant`` section.
        """
        self.config = config
        self.embeddings = self._create_embeddings()
        self.vectorstore = None  # set by connect_to_existing / create_from_documents

        # Define metadata fields that need payload indexes for filtering
        # as (payload field path, index schema type) pairs
        self.metadata_fields = [
            ("metadata.year", "keyword"),
            ("metadata.source", "keyword"),
            ("metadata.filename", "keyword"),
            # Add more metadata fields as needed
        ]

    def _create_embeddings(self) -> HuggingFaceEmbeddings:
        """Create embeddings model from configuration."""
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model_name = self.config["retriever"]["model"]
        normalize = self.config["retriever"]["normalize"]

        model_kwargs = {"device": device}
        encode_kwargs = {
            "normalize_embeddings": normalize,
            "batch_size": 100,
        }

        # For Matryoshka models, check if we need to truncate dimensions
        if "matryoshka" in model_name.lower():
            # The target collection determines the required dimensionality
            collection_name = self.config.get("qdrant", {}).get("collection_name", "")

            if "modernbert-embed-base-akryl-matryoshka" in collection_name:
                # This collection expects 768 dimensions
                truncate_dim = 768
                print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")

                # Use custom MatryoshkaEmbeddings
                return MatryoshkaEmbeddings(
                    model_name=model_name,
                    truncate_dim=truncate_dim,
                    model_kwargs=model_kwargs,
                    encode_kwargs=encode_kwargs,
                    show_progress=True,
                )

        # Use standard HuggingFaceEmbeddings for non-Matryoshka models
        # (also the fallback for Matryoshka models whose collection has no
        # special dimension requirement)
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            show_progress=True,
        )

    def _qdrant_kwargs(self) -> Dict[str, Any]:
        """Build the shared Qdrant connection kwargs from the ``qdrant`` config section."""
        qdrant_config = self.config["qdrant"]
        return {
            "url": qdrant_config["url"],
            "collection_name": qdrant_config["collection_name"],
            "prefer_grpc": qdrant_config.get("prefer_grpc", True),
            "api_key": qdrant_config.get("api_key", None),
        }

    def ensure_metadata_indexes(self) -> None:
        """
        Create payload indexes for all required metadata fields.
        This ensures filtering works properly, especially in Qdrant Cloud.

        No-op when no vectorstore is connected yet.
        """
        if not self.vectorstore:
            return

        collection_name = self.config["qdrant"]["collection_name"]

        for field_name, field_type in self.metadata_fields:
            try:
                # qdrant-client's documented keyword for the index type is
                # ``field_schema``; the previous ``field_type=`` keyword made
                # this call fail (the error was swallowed below), so the
                # indexes were never actually created. TODO(review): confirm
                # against the pinned qdrant-client version.
                self.vectorstore.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_type
                )
                print(f"Created payload index for {field_name} ({field_type})")
            except Exception as e:
                # Index might already exist or other error - log but continue
                print(f"Index creation for {field_name} ({field_type}): {str(e)}")

    def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
        """
        Connect to existing Qdrant collection.

        Args:
            force_recreate: If True, recreate the collection if dimension mismatch occurs

        Returns:
            QdrantVectorStore instance
        """
        kwargs_qdrant = self._qdrant_kwargs()
        if force_recreate:
            kwargs_qdrant["force_recreate"] = True

        self.vectorstore = QdrantVectorStore.from_existing_collection(
            embedding=self.embeddings,
            **kwargs_qdrant
        )

        # Ensure payload indexes exist for metadata filtering
        self.ensure_metadata_indexes()

        return self.vectorstore

    def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
        """
        Create new Qdrant collection from documents.

        Args:
            documents: List of Document objects

        Returns:
            QdrantVectorStore instance
        """
        self.vectorstore = QdrantVectorStore.from_documents(
            documents=documents,
            embedding=self.embeddings,
            **self._qdrant_kwargs()
        )

        # Ensure payload indexes exist for metadata filtering
        self.ensure_metadata_indexes()

        return self.vectorstore

    def delete_collection(self) -> Optional[QdrantVectorStore]:
        """
        Delete the current Qdrant collection.

        Requires an active connection (``self.vectorstore`` must be set).

        Returns:
            The (now stale) QdrantVectorStore instance, for compatibility
            with existing callers.
        """
        collection_name = self.config["qdrant"].get("collection_name")

        self.vectorstore.client.delete_collection(
            collection_name=collection_name
        )

        return self.vectorstore

    def get_vectorstore(self) -> Optional[QdrantVectorStore]:
        """Get current vectorstore instance (None before connect/create)."""
        return self.vectorstore
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
    """
    Get local Qdrant vector store (legacy function for compatibility).

    Args:
        config: Configuration dictionary

    Returns:
        QdrantVectorStore instance connected to the existing collection
    """
    return VectorStoreManager(config).connect_to_existing()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
    """
    Create new vector store from documents.

    Args:
        config: Configuration dictionary
        documents: List of Document objects

    Returns:
        QdrantVectorStore instance backed by the newly created collection
    """
    return VectorStoreManager(config).create_from_documents(documents)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_embeddings_model(config: Dict[str, Any]) -> HuggingFaceEmbeddings:
    """
    Create embeddings model from configuration (legacy function).

    Args:
        config: Configuration dictionary

    Returns:
        HuggingFaceEmbeddings instance loaded by a fresh VectorStoreManager
    """
    return VectorStoreManager(config).embeddings
|