Ara Yeroyan committed
Commit f5df983 · 1 Parent(s): 26449fc
src/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Audit QA Refactored Module
+ A modular and maintainable RAG pipeline for audit report analysis.
+ """
+
+ from .pipeline import PipelineManager
+ from .config.loader import load_config
+
+ __version__ = "2.0.0"
+ __all__ = ["PipelineManager", "load_config"]
src/config/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Configuration management for Audit QA."""
+
+ from .loader import load_config, get_nested_config
+
+ __all__ = ["load_config", "get_nested_config"]
src/config/collections.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "docling": {
+     "model": "BAAI/bge-m3",
+     "description": "Default collection with BGE-M3 embedding model"
+   },
+   "modernbert-embed-base-akryl-matryoshka": {
+     "model": "Akryl/modernbert-embed-base-akryl-matryoshka",
+     "description": "ModernBERT embedding model with matryoshka representation"
+   },
+   "sentence-transformers-all-MiniLM-L6-v2": {
+     "model": "sentence-transformers/all-MiniLM-L6-v2",
+     "description": "Sentence transformers MiniLM model"
+   },
+   "sentence-transformers-all-mpnet-base-v2": {
+     "model": "sentence-transformers/all-mpnet-base-v2",
+     "description": "Sentence transformers MPNet model"
+   },
+   "BAAI-bge-m3": {
+     "model": "BAAI/bge-m3",
+     "description": "BAAI BGE-M3 multilingual embedding model"
+   }
+ }
src/config/loader.py ADDED
@@ -0,0 +1,167 @@
+ """Configuration loader for YAML settings."""
+
+ import json
+ import os
+ import re
+ import yaml
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
+     """
+     Load configuration from YAML file.
+
+     Args:
+         config_path: Path to config file. If None, uses default settings.yaml
+
+     Returns:
+         Dictionary containing configuration settings
+     """
+     if config_path is None:
+         # Default to settings.yaml in the same directory as this file
+         config_path = Path(__file__).parent / "settings.yaml"
+
+     config_path = Path(config_path)
+
+     if not config_path.exists():
+         raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+     with open(config_path, 'r', encoding='utf-8') as f:
+         content = f.read()
+
+     def replace_env_vars(match):
+         env_var = match.group(1)
+         return os.getenv(env_var, match.group(0))  # Return original if env var not found
+
+     # Replace ${VAR} patterns with environment variables
+     content = re.sub(r'\$\{([^}]+)\}', replace_env_vars, content)
+
+     config = yaml.safe_load(content)
+
+     # Override with environment variables if they exist
+     config = _override_with_env_vars(config)
+
+     return config
+
+
+ def _override_with_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
+     """Override config values with environment variables where available."""
+
+     # Map environment variables to config paths
+     env_mappings = {
+         'QDRANT_URL': ['qdrant', 'url'],
+         'QDRANT_COLLECTION': ['qdrant', 'collection_name'],
+         'QDRANT_API_KEY': ['qdrant', 'api_key'],
+         'RETRIEVER_MODEL': ['retriever', 'model'],
+         'RANKER_MODEL': ['ranker', 'model'],
+         'READER_TYPE': ['reader', 'default_type'],
+         'MAX_TOKENS': ['reader', 'max_tokens'],
+         'MISTRAL_API_KEY': ['reader', 'MISTRAL', 'api_key'],
+         'OPENAI_API_KEY': ['reader', 'OPENAI', 'api_key'],
+         'NEBIUS_API_KEY': ['reader', 'INF_PROVIDERS', 'api_key'],
+         'NVIDIA_SERVER_API_KEY': ['reader', 'NVIDIA', 'api_key'],
+         'SERVERLESS_API_KEY': ['reader', 'SERVERLESS', 'api_key'],
+         'DEDICATED_API_KEY': ['reader', 'DEDICATED', 'api_key'],
+         'OPENROUTER_API_KEY': ['reader', 'OPENROUTER', 'api_key'],
+     }
+
+     for env_var, config_path in env_mappings.items():
+         env_value = os.getenv(env_var)
+         if env_value:
+             # Navigate to the nested config location
+             current = config
+             for key in config_path[:-1]:
+                 if key not in current:
+                     current[key] = {}
+                 current = current[key]
+
+             # Set the final value, converting to appropriate type
+             final_key = config_path[-1]
+             if final_key in ['top_k', 'max_tokens', 'num_predict']:
+                 current[final_key] = int(env_value)
+             elif final_key in ['normalize', 'prefer_grpc']:
+                 current[final_key] = env_value.lower() in ('true', '1', 'yes')
+             elif final_key == 'temperature':
+                 current[final_key] = float(env_value)
+             else:
+                 current[final_key] = env_value
+
+     return config
+
+
+ def get_nested_config(config: Dict[str, Any], path: str, default=None):
+     """
+     Get a nested configuration value using dot notation.
+
+     Args:
+         config: Configuration dictionary
+         path: Dot-separated path (e.g., 'reader.MISTRAL.model')
+         default: Default value if path not found
+
+     Returns:
+         Configuration value or default
+     """
+     keys = path.split('.')
+     current = config
+
+     try:
+         for key in keys:
+             current = current[key]
+         return current
+     except (KeyError, TypeError):
+         return default
+
+
+ def load_collections_mapping() -> Dict[str, Dict[str, str]]:
+     """Load collections mapping from JSON file."""
+     collections_file = Path(__file__).parent / "collections.json"
+
+     if not collections_file.exists():
+         # Return default mapping if file doesn't exist
+         return {
+             "docling": {
+                 "model": "sentence-transformers/all-MiniLM-L6-v2",
+                 "description": "Default collection"
+             }
+         }
+
+     with open(collections_file, 'r', encoding='utf-8') as f:
+         return json.load(f)
+
+
+ def get_embedding_model_for_collection(collection_name: str) -> Optional[str]:
+     """Get embedding model for a specific collection name."""
+     collections = load_collections_mapping()
+
+     if collection_name in collections:
+         return collections[collection_name]["model"]
+
+     # Try to infer from collection name patterns
+     if "modernbert" in collection_name.lower():
+         return "Akryl/modernbert-embed-base-akryl-matryoshka"
+     elif "minilm" in collection_name.lower():
+         return "sentence-transformers/all-MiniLM-L6-v2"
+     elif "mpnet" in collection_name.lower():
+         return "sentence-transformers/all-mpnet-base-v2"
+     elif "bge" in collection_name.lower():
+         return "BAAI/bge-m3"
+
+     return None
+
+
+ def get_collection_info(collection_name: str) -> Dict[str, str]:
+     """Get full collection information including model and description."""
+     collections = load_collections_mapping()
+
+     if collection_name in collections:
+         return collections[collection_name]
+
+     # Return inferred info for unknown collections
+     model = get_embedding_model_for_collection(collection_name)
+     return {
+         "model": model or "unknown",
+         "description": f"Auto-inferred collection: {collection_name}"
+     }
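A minimal usage sketch for the loader above, run from the repo root (assuming the package directory is importable as `src`; the collection name passed here is hypothetical):

from src.config.loader import load_config, get_nested_config, get_embedding_model_for_collection

config = load_config()  # reads settings.yaml next to the module and expands ${VAR} patterns

# Dot-notation lookup with a fallback for missing paths
reader_model = get_nested_config(config, "reader.OPENAI.model", default="gpt-4o-mini")

# Collection-to-model resolution, exercising the name-pattern fallback
model = get_embedding_model_for_collection("my-bge-collection")  # -> "BAAI/bge-m3"
print(reader_model, model)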
src/config/settings.yaml ADDED
@@ -0,0 +1,92 @@
+ # Audit QA Configuration
+ # Converted from model_params.cfg to YAML format
+
+ qdrant:
+   # url: "http://10.1.4.192:8803"
+   url: "https://2c6d0136-b6ca-4400-bac5-1703f58abc43.europe-west3-0.gcp.cloud.qdrant.io"
+   collection_name: "docling"
+   prefer_grpc: true
+   api_key: "${QDRANT_API_KEY}"  # Load from environment variable
+
+ retriever:
+   model: "BAAI/bge-m3"
+   normalize: true
+   top_k: 20
+
+ retrieval:
+   use_reranking: true
+   reranker_model: "BAAI/bge-reranker-v2-m3"
+   reranker_top_k: 5
+
+ ranker:
+   model: "BAAI/bge-reranker-v2-m3"
+   top_k: 5
+
+ bm25:
+   top_k: 20
+
+ hybrid:
+   default_mode: "vector_only"  # Options: vector_only, sparse_only, hybrid
+   default_alpha: 0.5  # Weight for vector scores (0.5 = equal weight)
+
+ reader:
+   default_type: "OPENAI"
+   max_tokens: 768
+
+   # Different LLM provider configurations
+   INF_PROVIDERS:
+     model: "meta-llama/Llama-3.1-8B-Instruct"
+     provider: "nebius"
+
+   # Not working
+   NVIDIA:
+     model: "meta-llama/Llama-3.1-8B-Instruct"
+     endpoint: "https://huggingface.co/api/integrations/dgx/v1"
+
+   # Not working
+   DEDICATED:
+     model: "meta-llama/Llama-3.1-8B-Instruct"
+     endpoint: "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
+
+   MISTRAL:
+     model: "mistral-medium-latest"
+
+   OPENAI:
+     model: "gpt-4o-mini"
+
+   OLLAMA:
+     model: "mistral-small3.1:24b-instruct-2503-q8_0"
+     base_url: "http://10.1.4.192:11434/"
+     temperature: 0.8
+     num_predict: 256
+
+   OPENROUTER:
+     model: "moonshotai/kimi-k2:free"
+     base_url: "https://openrouter.ai/api/v1"
+     temperature: 0.7
+     max_tokens: 1000
+     # site_url: "https://your-site.com"  # optional, for OpenRouter ranking
+     # site_name: "Your Site Name"  # optional, for OpenRouter ranking
+
+ app:
+   dropdown_default: "Annual Consolidated OAG 2024"
+
+ # File paths
+ paths:
+   chunks_file: "reports/docling_chunks.json"
+   reports_dir: "reports"
+
+ # Feature toggles
+ features:
+   enable_session: true
+   enable_logging: true
+
+ # Logging and HuggingFace scheduler configuration
+ logging:
+   json_dataset_dir: "json_dataset"
+   huggingface:
+     repo_id: "GIZ/spaces_logs"
+     repo_type: "dataset"
+     folder_path: "json_dataset"
+     path_in_repo: "audit_chatbot"
+     token_env_var: "SPACES_LOG"
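A quick check of the override behavior wired up in loader.py: any variable in env_mappings set before load_config() takes precedence over the YAML value. A sketch, assuming no other overrides are set:

import os
from src.config.loader import load_config

os.environ["QDRANT_COLLECTION"] = "BAAI-bge-m3"  # mapped to ['qdrant', 'collection_name']
config = load_config()
assert config["qdrant"]["collection_name"] == "BAAI-bge-m3"  # YAML said "docling"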
src/llm/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """LLM adapters and utilities."""
+
+ from .adapters import LLMRegistry, get_llm_client
+ from .templates import get_message_template, PromptTemplate, create_audit_prompt
+
+ __all__ = ["LLMRegistry", "get_llm_client", "get_message_template", "PromptTemplate", "create_audit_prompt"]
src/llm/adapters.py ADDED
@@ -0,0 +1,409 @@
+ """LLM client adapters for different providers."""
+
+ from typing import Dict, Any, List, Optional, Union
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+
+ # LangChain imports
+ from langchain_mistralai.chat_models import ChatMistralAI
+ from langchain_openai.chat_models import ChatOpenAI
+ from langchain_ollama import ChatOllama
+
+ # Legacy client dependencies
+ from huggingface_hub import InferenceClient
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain_community.chat_models.huggingface import ChatHuggingFace
+
+ # Configuration loader
+ from ..config.loader import load_config
+
+ # Load configuration once at module level
+ _config = load_config()
+
+
+ # Legacy client factory functions (inlined from auditqa_old.reader)
+ def _create_inf_provider_client():
+     """Create INF_PROVIDERS client."""
+     reader_config = _config.get("reader", {})
+     inf_config = reader_config.get("INF_PROVIDERS", {})
+
+     api_key = inf_config.get("api_key")
+     if not api_key:
+         raise ValueError("INF_PROVIDERS api_key not found in configuration")
+
+     provider = inf_config.get("provider")
+     if not provider:
+         raise ValueError("INF_PROVIDERS provider not found in configuration")
+
+     return InferenceClient(
+         provider=provider,
+         api_key=api_key,
+         bill_to="GIZ",
+     )
+
+
+ def _create_nvidia_client():
+     """Create NVIDIA client."""
+     reader_config = _config.get("reader", {})
+     nvidia_config = reader_config.get("NVIDIA", {})
+
+     api_key = nvidia_config.get("api_key")
+     if not api_key:
+         raise ValueError("NVIDIA api_key not found in configuration")
+
+     endpoint = nvidia_config.get("endpoint")
+     if not endpoint:
+         raise ValueError("NVIDIA endpoint not found in configuration")
+
+     return InferenceClient(
+         base_url=endpoint,
+         api_key=api_key
+     )
+
+
+ def _create_serverless_client():
+     """Create serverless API client."""
+     reader_config = _config.get("reader", {})
+     serverless_config = reader_config.get("SERVERLESS", {})
+
+     api_key = serverless_config.get("api_key")
+     if not api_key:
+         raise ValueError("SERVERLESS api_key not found in configuration")
+
+     model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
+
+     return InferenceClient(
+         model=model_id,
+         api_key=api_key,
+     )
+
+
+ def _create_dedicated_endpoint_client():
+     """Create dedicated endpoint client."""
+     reader_config = _config.get("reader", {})
+     dedicated_config = reader_config.get("DEDICATED", {})
+
+     api_key = dedicated_config.get("api_key")
+     if not api_key:
+         raise ValueError("DEDICATED api_key not found in configuration")
+
+     endpoint = dedicated_config.get("endpoint")
+     if not endpoint:
+         raise ValueError("DEDICATED endpoint not found in configuration")
+
+     max_tokens = dedicated_config.get("max_tokens", 768)
+
+     # Set up the streaming callback handler
+     callback = StreamingStdOutCallbackHandler()
+
+     # Initialize the HuggingFaceEndpoint with streaming enabled
+     llm_qa = HuggingFaceEndpoint(
+         endpoint_url=endpoint,
+         max_new_tokens=int(max_tokens),
+         repetition_penalty=1.03,
+         timeout=70,
+         huggingfacehub_api_token=api_key,
+         streaming=True,
+         callbacks=[callback]
+     )
+
+     # Create a ChatHuggingFace instance with the streaming-enabled endpoint
+     return ChatHuggingFace(llm=llm_qa)
+
+
+ @dataclass
+ class LLMResponse:
+     """Standardized LLM response format."""
+     content: str
+     model: str
+     provider: str
+     metadata: Optional[Dict[str, Any]] = None
+
+
+ class BaseLLMAdapter(ABC):
+     """Base class for LLM adapters."""
+
+     def __init__(self, config: Dict[str, Any]):
+         self.config = config
+
+     @abstractmethod
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response from messages."""
+         pass
+
+     @abstractmethod
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response from messages."""
+         pass
+
+
+ class MistralAdapter(BaseLLMAdapter):
+     """Adapter for Mistral AI models."""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.model = ChatMistralAI(
+             model=config.get("model", "mistral-medium-latest")
+         )
+
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response using Mistral."""
+         response = self.model.invoke(messages)
+
+         return LLMResponse(
+             content=response.content,
+             model=self.config.get("model", "mistral-medium-latest"),
+             provider="mistral",
+             metadata={"usage": getattr(response, 'usage_metadata', {})}
+         )
+
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response using Mistral."""
+         for chunk in self.model.stream(messages):
+             if chunk.content:
+                 yield chunk.content
+
+
+ class OpenAIAdapter(BaseLLMAdapter):
+     """Adapter for OpenAI models."""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.model = ChatOpenAI(
+             model=config.get("model", "gpt-4o-mini")
+         )
+
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response using OpenAI."""
+         response = self.model.invoke(messages)
+
+         return LLMResponse(
+             content=response.content,
+             model=self.config.get("model", "gpt-4o-mini"),
+             provider="openai",
+             metadata={"usage": getattr(response, 'usage_metadata', {})}
+         )
+
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response using OpenAI."""
+         for chunk in self.model.stream(messages):
+             if chunk.content:
+                 yield chunk.content
+
+
+ class OllamaAdapter(BaseLLMAdapter):
+     """Adapter for Ollama models."""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.model = ChatOllama(
+             model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
+             base_url=config.get("base_url", "http://localhost:11434/"),
+             temperature=config.get("temperature", 0.8),
+             num_predict=config.get("num_predict", 256)
+         )
+
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response using Ollama."""
+         response = self.model.invoke(messages)
+
+         return LLMResponse(
+             content=response.content,
+             model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
+             provider="ollama",
+             metadata={}
+         )
+
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response using Ollama."""
+         for chunk in self.model.stream(messages):
+             if chunk.content:
+                 yield chunk.content
+
+
+ class OpenRouterAdapter(BaseLLMAdapter):
+     """Adapter for OpenRouter models."""
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+
+         # Prepare custom headers for OpenRouter (optional)
+         headers = {}
+         if config.get("site_url"):
+             headers["HTTP-Referer"] = config["site_url"]
+         if config.get("site_name"):
+             headers["X-Title"] = config["site_name"]
+
+         # Initialize ChatOpenAI with OpenRouter configuration
+         self.model = ChatOpenAI(
+             model=config.get("model", "openai/gpt-3.5-turbo"),
+             api_key=config.get("api_key"),
+             base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
+             default_headers=headers if headers else {},
+             temperature=config.get("temperature", 0.7),
+             max_tokens=config.get("max_tokens", 1000)
+         )
+
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response using OpenRouter."""
+         response = self.model.invoke(messages)
+
+         return LLMResponse(
+             content=response.content,
+             model=self.config.get("model", "openai/gpt-3.5-turbo"),
+             provider="openrouter",
+             metadata={"usage": getattr(response, 'usage_metadata', {})}
+         )
+
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response using OpenRouter."""
+         for chunk in self.model.stream(messages):
+             if chunk.content:
+                 yield chunk.content
+
+
+ class LegacyAdapter(BaseLLMAdapter):
+     """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""
+
+     def __init__(self, config: Dict[str, Any], client_type: str):
+         super().__init__(config)
+         self.client_type = client_type
+         self.client = self._create_client()
+
+     def _create_client(self):
+         """Create legacy client based on type."""
+         if self.client_type == "INF_PROVIDERS":
+             return _create_inf_provider_client()
+         elif self.client_type == "NVIDIA":
+             return _create_nvidia_client()
+         elif self.client_type == "DEDICATED":
+             return _create_dedicated_endpoint_client()
+         else:  # SERVERLESS
+             return _create_serverless_client()
+
+     def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
+         """Generate response using legacy client."""
+         max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
+
+         if self.client_type == "INF_PROVIDERS":
+             response = self.client.chat.completions.create(
+                 model=self.config.get("model"),
+                 messages=messages,
+                 max_tokens=max_tokens
+             )
+             content = response.choices[0].message.content
+
+         elif self.client_type == "NVIDIA":
+             response = self.client.chat_completion(
+                 model=self.config.get("model"),
+                 messages=messages,
+                 max_tokens=max_tokens
+             )
+             content = response.choices[0].message.content
+
+         else:  # DEDICATED or SERVERLESS
+             response = self.client.chat_completion(
+                 messages=messages,
+                 max_tokens=max_tokens
+             )
+             content = response.choices[0].message.content
+
+         return LLMResponse(
+             content=content,
+             model=self.config.get("model", "unknown"),
+             provider=self.client_type.lower(),
+             metadata={}
+         )
+
+     def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
+         """Generate streaming response using legacy client."""
+         # Legacy clients may not support streaming in the same way
+         # This is a simplified implementation
+         response = self.generate(messages, **kwargs)
+         words = response.content.split()
+         for word in words:
+             yield word + " "
+
+
+ class LLMRegistry:
+     """Registry for managing different LLM adapters."""
+
+     def __init__(self):
+         self.adapters = {}
+         self.adapter_configs = {}
+
+     def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
+         """Register an LLM adapter (lazy instantiation)."""
+         self.adapter_configs[name] = (adapter_class, config)
+
+     def get_adapter(self, name: str) -> BaseLLMAdapter:
+         """Get an LLM adapter by name (lazy instantiation)."""
+         if name not in self.adapter_configs:
+             raise ValueError(f"Unknown LLM adapter: {name}")
+
+         # Lazy instantiation - only create when needed
+         if name not in self.adapters:
+             adapter_class, config = self.adapter_configs[name]
+             self.adapters[name] = adapter_class(config)
+
+         return self.adapters[name]
+
+     def list_adapters(self) -> List[str]:
+         """List available adapter names."""
+         return list(self.adapter_configs.keys())
+
+
+ def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
+     """
+     Create and populate LLM registry from configuration.
+
+     Args:
+         config: Configuration dictionary
+
+     Returns:
+         Populated LLMRegistry
+     """
+     registry = LLMRegistry()
+     reader_config = config.get("reader", {})
+
+     # Register simple adapters
+     if "MISTRAL" in reader_config:
+         registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
+
+     if "OPENAI" in reader_config:
+         registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
+
+     if "OLLAMA" in reader_config:
+         registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
+
+     if "OPENROUTER" in reader_config:
+         registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])
+
+     # Register legacy adapters
+     # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
+     legacy_types = ["INF_PROVIDERS"]
+     for legacy_type in legacy_types:
+         if legacy_type in reader_config:
+             registry.register_adapter(
+                 legacy_type.lower(),
+                 lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
+                 reader_config[legacy_type]
+             )
+
+     return registry
+
+
+ def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
+     """
+     Get LLM client for specified provider.
+
+     Args:
+         provider: Provider name (mistral, openai, ollama, etc.)
+         config: Configuration dictionary
+
+     Returns:
+         LLM adapter instance
+     """
+     registry = create_llm_registry(config)
+     return registry.get_adapter(provider)
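A minimal sketch of the registry API above, assuming OPENAI_API_KEY is set in the environment so ChatOpenAI can authenticate:

from src.config.loader import load_config
from src.llm.adapters import create_llm_registry

config = load_config()
registry = create_llm_registry(config)
print(registry.list_adapters())  # e.g. ['mistral', 'openai', 'ollama', 'openrouter', 'inf_providers']

adapter = registry.get_adapter("openai")  # instantiated lazily on first access
response = adapter.generate([{"role": "user", "content": "Ping"}])
print(response.provider, response.content)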
src/llm/templates.py ADDED
@@ -0,0 +1,232 @@
+ """LLM prompt templates and message formatting utilities."""
+
+ from typing import List, Dict, Any, Union
+ from dataclasses import dataclass
+ from langchain.schema import SystemMessage, HumanMessage
+
+
+ @dataclass
+ class PromptTemplate:
+     """Template for managing prompts with variables."""
+
+     system_prompt: str
+     user_prompt_template: str
+
+     def format(self, **kwargs) -> tuple:
+         """Format the template with provided variables."""
+         formatted_user = self.user_prompt_template.format(**kwargs)
+         return self.system_prompt, formatted_user
+
+
+ # Default system prompt for audit Q&A
+ DEFAULT_AUDIT_SYSTEM_PROMPT = """
+ You are AuditQ&A, an AI Assistant for audit reports. Answer questions directly and factually based on the provided context.
+
+ Guidelines:
+ - Answer directly and concisely (2-3 sentences maximum)
+ - Use specific facts and numbers from the context
+ - Cite sources using [Doc i] format
+ - Be factual, not opinionated
+ - Avoid phrases like "From my point of view", "I think", "It seems"
+
+ Examples:
+
+ Query: "What challenges arise from contradictory PDM implementation guidelines?"
+ Context: [Retrieved documents about PDM guidelines contradictions]
+ Answer: "Contradictory PDM implementation guidelines cause challenges during implementation, as entities receive numerous and often conflicting directives from different authorities. For example, guidelines on transfer of funds to PDM SACCOs differ between the PDM Secretariat and PSST, and there are conflicting directives on fund diversion from various authorities."
+
+ Query: "What was the supplementary funding obtained for the wage budget?"
+ Context: [Retrieved documents about wage budget funding]
+ Answer: "The supplementary funding obtained for the wage budget was UGX.2,208,040,656."
+
+ Now answer the following question based on the provided context:
+ """
+
+ # Default user prompt template
+ DEFAULT_USER_PROMPT_TEMPLATE = """Passages:
+ {context}
+ -----------------------
+ Question: {question} - Explained to audit expert
+ Answer in English with the passage citations:
+ """
+
+
+ def create_audit_prompt(context_list: List[str], query: str) -> List[Dict[str, str]]:
+     """
+     Create audit Q&A prompt messages from context and query.
+
+     Args:
+         context_list: List of context passages
+         query: User query
+
+     Returns:
+         List of message dictionaries for LLM
+     """
+     # Join context passages with numbering
+     numbered_context = []
+     for i, passage in enumerate(context_list, 1):
+         numbered_context.append(f"Doc {i}: {passage}")
+
+     context_str = "\n\n".join(numbered_context)
+
+     # Format user prompt
+     user_prompt = DEFAULT_USER_PROMPT_TEMPLATE.format(
+         context=context_str,
+         question=query
+     )
+
+     # Return as message format
+     messages = [
+         {"role": "system", "content": DEFAULT_AUDIT_SYSTEM_PROMPT},
+         {"role": "user", "content": user_prompt}
+     ]
+
+     return messages
+
+
+ def get_message_template(
+     provider_type: str,
+     system_prompt: str,
+     user_prompt: str
+ ) -> List[Union[Dict[str, str], SystemMessage, HumanMessage]]:
+     """
+     Get message template based on LLM provider type.
+
+     Args:
+         provider_type: Type of LLM provider
+         system_prompt: System prompt content
+         user_prompt: User prompt content
+
+     Returns:
+         List of messages in the appropriate format for the provider
+     """
+     provider_type = provider_type.upper()
+
+     if provider_type in ['NVIDIA', 'INF_PROVIDERS', 'MISTRAL', 'OPENAI', 'OPENROUTER']:
+         # Dictionary format for API-based providers
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt}
+         ]
+     elif provider_type in ['DEDICATED', 'SERVERLESS', 'OLLAMA']:
+         # LangChain message objects for local/dedicated providers
+         messages = [
+             SystemMessage(content=system_prompt),
+             HumanMessage(content=user_prompt)
+         ]
+     else:
+         # Default to dictionary format
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt}
+         ]
+
+     return messages
+
+
+ def create_custom_prompt_template(
+     system_prompt: str,
+     user_template: str
+ ) -> PromptTemplate:
+     """
+     Create a custom prompt template.
+
+     Args:
+         system_prompt: System prompt content
+         user_template: User prompt template with placeholders
+
+     Returns:
+         PromptTemplate instance
+     """
+     return PromptTemplate(
+         system_prompt=system_prompt,
+         user_prompt_template=user_template
+     )
+
+
+ def create_evaluation_prompt(context_list: List[str], query: str, expected_answer: str) -> List[Dict[str, str]]:
+     """
+     Create prompt for evaluation purposes with expected answer.
+
+     Args:
+         context_list: List of context passages
+         query: User query
+         expected_answer: Expected/ground truth answer
+
+     Returns:
+         List of message dictionaries for evaluation
+     """
+     # Join context passages
+     context_str = "\n\n".join([f"Doc {i}: {passage}" for i, passage in enumerate(context_list, 1)])
+
+     evaluation_system_prompt = """
+ You are an evaluation assistant. Given context passages, a question, and an expected answer,
+ evaluate how well the provided context supports answering the question accurately.
+
+ Provide your evaluation focusing on:
+ 1. Relevance of the context to the question
+ 2. Completeness of information needed to answer
+ 3. Quality and accuracy of supporting details
+ """
+
+     user_prompt = f"""Context Passages:
+ {context_str}
+
+ Question: {query}
+ Expected Answer: {expected_answer}
+
+ Evaluation:"""
+
+     return [
+         {"role": "system", "content": evaluation_system_prompt},
+         {"role": "user", "content": user_prompt}
+     ]
+
+
+ def get_prompt_variants() -> Dict[str, PromptTemplate]:
+     """
+     Get different prompt template variants for testing.
+
+     Returns:
+         Dictionary of named prompt templates
+     """
+     variants = {
+         "standard": create_custom_prompt_template(
+             DEFAULT_AUDIT_SYSTEM_PROMPT,
+             DEFAULT_USER_PROMPT_TEMPLATE
+         ),
+
+         "concise": create_custom_prompt_template(
+             """You are an audit report AI assistant. Provide clear, concise answers based on the given context passages. Always cite sources using [Doc i] format.""",
+             """Context:\n{context}\n\nQuestion: {question}\nAnswer:"""
+         ),
+
+         "detailed": create_custom_prompt_template(
+             DEFAULT_AUDIT_SYSTEM_PROMPT + """\n\nAdditional Instructions:
+ - Provide detailed explanations with specific examples
+ - Include relevant numbers, dates, and financial figures when available
+ - Structure your response with clear headings when appropriate
+ - Explain the significance of findings in the context of governance and accountability""",
+             DEFAULT_USER_PROMPT_TEMPLATE
+         )
+     }
+
+     return variants
+
+
+ # Backward compatibility function
+ def format_context_with_citations(context_list: List[str]) -> str:
+     """
+     Format context list with document citations.
+
+     Args:
+         context_list: List of context passages
+
+     Returns:
+         Formatted context string with citations
+     """
+     formatted_passages = []
+     for i, passage in enumerate(context_list, 1):
+         formatted_passages.append(f"Doc {i}: {passage}")
+
+     return "\n\n".join(formatted_passages)
src/loader.py ADDED
@@ -0,0 +1,115 @@
+ """Data loading utilities for chunks and JSON files."""
+
+ import json
+ from pathlib import Path
+ from typing import List, Dict, Any
+ from langchain.docstore.document import Document
+
+
+ def load_json(filepath: Path | str) -> List[Dict[str, Any]]:
+     """
+     Load JSON data from file.
+
+     Args:
+         filepath: Path to JSON file
+
+     Returns:
+         List of dictionaries containing the JSON data
+     """
+     filepath = Path(filepath)
+
+     if not filepath.exists():
+         raise FileNotFoundError(f"JSON file not found: {filepath}")
+
+     with open(filepath, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     return data
+
+
+ def open_file(filepath: Path | str) -> str:
+     """
+     Open and read a text file.
+
+     Args:
+         filepath: Path to text file
+
+     Returns:
+         File contents as string
+     """
+     filepath = Path(filepath)
+
+     if not filepath.exists():
+         raise FileNotFoundError(f"File not found: {filepath}")
+
+     with open(filepath, 'r', encoding='utf-8') as f:
+         content = f.read()
+
+     return content
+
+
+ def load_chunks(chunks_file: Path | str | None = None) -> List[Dict[str, Any]]:
+     """
+     Load document chunks from JSON file.
+
+     Args:
+         chunks_file: Path to chunks JSON file. If None, uses default path.
+
+     Returns:
+         List of chunk dictionaries
+     """
+     if chunks_file is None:
+         chunks_file = Path("reports/docling_chunks.json")
+
+     return load_json(chunks_file)
+
+
+ def chunks_to_documents(chunks: List[Dict[str, Any]]) -> List[Document]:
+     """
+     Convert chunk dictionaries to LangChain Document objects.
+
+     Args:
+         chunks: List of chunk dictionaries
+
+     Returns:
+         List of Document objects
+     """
+     documents = []
+
+     for chunk in chunks:
+         doc = Document(
+             page_content=chunk.get("content", ""),
+             metadata=chunk.get("metadata", {})
+         )
+         documents.append(doc)
+
+     return documents
+
+
+ def validate_chunks(chunks: List[Dict[str, Any]]) -> bool:
+     """
+     Validate that chunks have required fields.
+
+     Args:
+         chunks: List of chunk dictionaries
+
+     Returns:
+         True if valid, raises ValueError if invalid
+     """
+     required_fields = ["content", "metadata"]
+
+     for i, chunk in enumerate(chunks):
+         for field in required_fields:
+             if field not in chunk:
+                 raise ValueError(f"Chunk {i} missing required field: {field}")
+
+         # Validate metadata has required fields
+         metadata = chunk["metadata"]
+         if not isinstance(metadata, dict):
+             raise ValueError(f"Chunk {i} metadata must be a dictionary")
+
+         # Check for common metadata fields
+         if "filename" not in metadata:
+             raise ValueError(f"Chunk {i} metadata missing 'filename' field")
+
+     return True
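A sketch of the expected chunk shape and the load/validate/convert flow, assuming reports/docling_chunks.json exists relative to the working directory:

from src.loader import load_chunks, validate_chunks, chunks_to_documents

# Expected shape: [{"content": "...", "metadata": {"filename": "...", ...}}, ...]
chunks = load_chunks()  # defaults to reports/docling_chunks.json
validate_chunks(chunks)  # raises ValueError on missing fields
documents = chunks_to_documents(chunks)
print(len(documents), documents[0].metadata.get("filename"))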
src/logging.py ADDED
@@ -0,0 +1,194 @@
+ """Logging utilities for the Audit QA pipeline."""
+ import json
+ import logging
+ import os
+ from uuid import uuid4
+ from pathlib import Path
+ from threading import Lock
+ from datetime import datetime
+ from typing import Dict, Any, Optional
+
+ from .config import load_config
+
+ def save_logs(
+     scheduler=None,
+     json_dataset_path: Optional[Path] = None,
+     logs_data: Optional[Dict[str, Any]] = None,
+     feedback: Optional[str] = None
+ ) -> None:
+     """
+     Append a structured log record to the JSONL dataset.
+
+     Args:
+         scheduler: Optional HuggingFace scheduler; its lock is reused when available
+         json_dataset_path: Path to the JSONL dataset file
+         logs_data: Log data dictionary
+         feedback: Optional user feedback
+
+     Note:
+         Kept signature-compatible with the legacy implementation.
+         Records are appended as one JSON object per line.
+     """
+     if not is_logging_enabled():
+         return
+     try:
+         current_time = datetime.now().timestamp()
+         logs_data["time"] = str(current_time)
+         if feedback:
+             logs_data["feedback"] = feedback
+         logs_data["record_id"] = str(uuid4())
+         field_order = [
+             "record_id",
+             "session_id",
+             "time",
+             "session_duration_seconds",
+             "client_location",
+             "platform",
+             "system_prompt",
+             "sources",
+             "reports",
+             "subtype",
+             "year",
+             "question",
+             "retriever",
+             "endpoint_type",
+             "reader",
+             "docs",
+             "answer",
+             "feedback"
+         ]
+         ordered_logs = {k: logs_data.get(k) for k in field_order if k in logs_data}
+         lock = getattr(scheduler, "lock", None)
+         if lock is None:
+             lock = Lock()
+         with lock:
+             with open(json_dataset_path, 'a') as f:
+                 json.dump(ordered_logs, f)
+                 f.write("\n")
+         logging.info("logging done")
+     except Exception as e:
+         logging.error(f"Error saving logs: {e}")
+         raise
+
+
+ def setup_logging(log_level: str = "INFO", log_file: Optional[str] = None) -> None:
+     """
+     Set up logging configuration.
+
+     Args:
+         log_level: Logging level
+         log_file: Optional log file path
+     """
+     if not is_logging_enabled():
+         return
+
+     # Configure logging
+     logging.basicConfig(
+         level=getattr(logging, log_level.upper()),
+         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         handlers=[
+             logging.StreamHandler(),
+             logging.FileHandler(log_file) if log_file else logging.NullHandler()
+         ]
+     )
+
+
+ def log_query_response(
+     query: str,
+     response: str,
+     metadata: Optional[Dict[str, Any]] = None
+ ) -> None:
+     """
+     Log query and response for analysis.
+
+     Args:
+         query: User query
+         response: System response
+         metadata: Additional metadata
+     """
+     if not is_logging_enabled():
+         return
+
+     logger = logging.getLogger(__name__)
+
+     log_entry = {
+         "query": query,
+         "response_length": len(response),
+         "metadata": metadata or {}
+     }
+
+     logger.info(f"Query processed: {log_entry}")
+
+
+ def log_error(error: Exception, context: Optional[Dict[str, Any]] = None) -> None:
+     """
+     Log error with context.
+
+     Args:
+         error: Exception that occurred
+         context: Additional context information
+     """
+     if not is_logging_enabled():
+         return
+
+     logger = logging.getLogger(__name__)
+
+     error_info = {
+         "error_type": type(error).__name__,
+         "error_message": str(error),
+         "context": context or {}
+     }
+
+     logger.error(f"Error occurred: {error_info}")
+
+
+ def log_performance_metrics(
+     operation: str,
+     duration: float,
+     metadata: Optional[Dict[str, Any]] = None
+ ) -> None:
+     """
+     Log performance metrics.
+
+     Args:
+         operation: Name of the operation
+         duration: Duration in seconds
+         metadata: Additional metadata
+     """
+     if not is_logging_enabled():
+         return
+
+     logger = logging.getLogger(__name__)
+
+     metrics = {
+         "operation": operation,
+         "duration_seconds": duration,
+         "metadata": metadata or {}
+     }
+
+     logger.info(f"Performance metrics: {metrics}")
+
+
+ def is_session_enabled() -> bool:
+     """
+     Returns True if session management is enabled, False otherwise.
+     Checks environment variable ENABLE_SESSION first, then config.
+     """
+     env = os.getenv("ENABLE_SESSION")
+     if env is not None:
+         return env.lower() in ("1", "true", "yes", "on")
+     config = load_config()
+     return config.get("features", {}).get("enable_session", True)
+
+
+ def is_logging_enabled() -> bool:
+     """
+     Returns True if logging is enabled, False otherwise.
+     Checks environment variable ENABLE_LOGGING first, then config.
+     """
+     env = os.getenv("ENABLE_LOGGING")
+     if env is not None:
+         return env.lower() in ("1", "true", "yes", "on")
+     config = load_config()
+     return config.get("features", {}).get("enable_logging", True)
+
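A sketch of the logging helpers, assuming ENABLE_LOGGING is unset so the config default (enable_logging: true) applies:

import time
from src.logging import setup_logging, log_performance_metrics

setup_logging("INFO")
start = time.time()
# ... run a retrieval or generation step here ...
log_performance_metrics("retrieval", time.time() - start, {"top_k": 20})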
src/pipeline.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main pipeline orchestrator for the Audit QA system."""
2
+ import time
3
+ from pathlib import Path
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Any, List, Optional
6
+
7
+ from langchain.docstore.document import Document
8
+
9
+ from .logging import log_error
10
+ from .llm.adapters import LLMRegistry
11
+ from .loader import chunks_to_documents
12
+ from .vectorstore import VectorStoreManager
13
+ from .retrieval.context import ContextRetriever
14
+ from .config.loader import get_embedding_model_for_collection
15
+
16
+
17
+
18
+ @dataclass
19
+ class PipelineResult:
20
+ """Result of pipeline execution."""
21
+ answer: str
22
+ sources: List[Document]
23
+ execution_time: float
24
+ metadata: Dict[str, Any]
25
+ query: str = "" # Add default value for query
26
+
27
+ def __post_init__(self):
28
+ """Post-initialization processing."""
29
+ if not self.query:
30
+ self.query = "Unknown query"
31
+
32
+
33
+ class PipelineManager:
34
+ """Main pipeline manager for the RAG system."""
35
+
36
+ def __init__(self, config: dict = None):
37
+ """
38
+ Initialize the pipeline manager.
39
+ """
40
+ self.config = config or {}
41
+ self.vectorstore_manager = None
42
+ self.context_retriever = None # Initialize as None
43
+ self.llm_client = None
44
+ self.report_service = None
45
+ self.chunks = None
46
+
47
+ # Initialize components
48
+ self._initialize_components()
49
+
50
+ def update_config(self, new_config: dict):
51
+ """
52
+ Update the pipeline configuration.
53
+ This is useful for experiments that need different settings.
54
+ """
55
+ if not isinstance(new_config, dict):
56
+ return
57
+
58
+ # Deep merge the new config with existing config
59
+ def deep_merge(base_dict, update_dict):
60
+ for key, value in update_dict.items():
61
+ if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
62
+ deep_merge(base_dict[key], value)
63
+ else:
64
+ base_dict[key] = value
65
+
66
+ deep_merge(self.config, new_config)
67
+
68
+ # Auto-infer embedding model from collection name if not "docling"
69
+ collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
70
+ if collection_name != 'docling':
71
+ inferred_model = get_embedding_model_for_collection(collection_name)
72
+ if inferred_model:
73
+ print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
74
+ if 'retriever' not in self.config:
75
+ self.config['retriever'] = {}
76
+ self.config['retriever']['model'] = inferred_model
77
+ # Set default normalize parameter if not present
78
+ if 'normalize' not in self.config['retriever']:
79
+ self.config['retriever']['normalize'] = True
80
+
81
+ # Also update vectorstore config if it exists
82
+ if 'vectorstore' in self.config:
83
+ self.config['vectorstore']['embedding_model'] = inferred_model
84
+
85
+ print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
86
+
87
+ # Re-initialize vectorstore manager with updated config
88
+ self._reinitialize_vectorstore_manager()
89
+
90
+ def _reinitialize_vectorstore_manager(self):
91
+ """Re-initialize vectorstore manager with current config."""
92
+ try:
93
+ self.vectorstore_manager = VectorStoreManager(self.config)
94
+ print("🔄 VectorStore manager re-initialized with updated config")
95
+ except Exception as e:
96
+ print(f"❌ Error re-initializing vectorstore manager: {e}")
97
+
98
+ def _get_reranker_model_name(self) -> str:
99
+ """
100
+ Get the reranker model name from configuration.
101
+
102
+ Returns:
103
+ Reranker model name or default
104
+ """
105
+ return (
106
+ self.config.get('retrieval', {}).get('reranker_model') or
107
+ self.config.get('ranker', {}).get('model') or
108
+ self.config.get('reranker_model') or
109
+ 'BAAI/bge-reranker-v2-m3'
110
+ )
111
+
112
+ def _initialize_components(self):
113
+ """Initialize pipeline components."""
114
+ try:
115
+ # Load config if not provided
116
+ if not self.config:
117
+ from auditqa.config.loader import load_config
118
+ self.config = load_config()
119
+
120
+ # Auto-infer embedding model from collection name if not "docling"
121
+ collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
122
+ if collection_name != 'docling':
123
+ inferred_model = get_embedding_model_for_collection(collection_name)
124
+ if inferred_model:
125
+ print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
126
+ if 'retriever' not in self.config:
127
+ self.config['retriever'] = {}
128
+ self.config['retriever']['model'] = inferred_model
129
+ # Set default normalize parameter if not present
130
+ if 'normalize' not in self.config['retriever']:
131
+ self.config['retriever']['normalize'] = True
132
+
133
+ # Also update vectorstore config if it exists
134
+ if 'vectorstore' in self.config:
135
+ self.config['vectorstore']['embedding_model'] = inferred_model
136
+
137
+ self.vectorstore_manager = VectorStoreManager(self.config)
138
+
139
+ self.llm_manager = LLMRegistry()
140
+
141
+ # Try to get LLM client using the correct method
142
+ self.llm_client = None
143
+ try:
144
+ # Try using get_adapter method (most likely correct)
145
+ self.llm_client = self.llm_manager.get_adapter("openai")
146
+ print("✅ LLM CLIENT: Initialized using get_adapter method")
147
+ except Exception as e:
148
+ try:
149
+ # Try direct instantiation with config
150
+ from auditqa.llm.adapters import get_llm_client
151
+ self.llm_client = get_llm_client("openai", self.config)
152
+ print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
153
+ except Exception as e2:
154
+ print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
155
+ # Try to create a simple LLM client directly
156
+ try:
157
+ from langchain_openai import ChatOpenAI
158
+ import os
159
+ api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
160
+ if api_key:
161
+ self.llm_client = ChatOpenAI(
162
+ model="gpt-3.5-turbo",
163
+ api_key=api_key,
164
+ temperature=0.1,
165
+ max_tokens=1000
166
+ )
167
+ print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
168
+ else:
169
+ print("❌ LLM CLIENT: No API key available")
170
+ except Exception as e3:
171
+ print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
172
+ self.llm_client = None
173
+
174
+ # Load system prompt
175
+ from auditqa.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
176
+ self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
177
+
178
+ # Initialize report service
179
+ try:
180
+ from auditqa.reporting.service import ReportService
181
+ self.report_service = ReportService()
182
+ except Exception as e:
183
+ print(f"Warning: Could not initialize report service: {e}")
184
+ self.report_service = None
185
+
186
+ except Exception as e:
187
+ print(f"Warning: Error initializing components: {e}")
188
+
189
+ def test_retrieval(
190
+ self,
191
+ query: str,
192
+ reports: List[str] = None,
193
+ sources: str = None,
194
+ subtype: List[str] = None,
195
+ k: int = None,
196
+ search_mode: str = None,
197
+ search_alpha: float = None,
198
+ use_reranking: bool = True
199
+ ) -> Dict[str, Any]:
200
+ """
201
+ Test retrieval only without LLM inference.
202
+
203
+ Args:
204
+ query: User query
205
+ reports: List of specific report filenames
206
+ sources: Source category
207
+ subtype: List of subtypes
208
+ k: Number of documents to retrieve
209
+ search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
210
+ search_alpha: Weight for vector scores in hybrid mode
211
+ use_reranking: Whether to use reranking
212
+
213
+ Returns:
214
+ Dictionary with retrieval results and metadata
215
+ """
216
+ start_time = time.time()
217
+
218
+ try:
219
+ # Set default search parameters if not provided
220
+ if search_mode is None:
221
+ search_mode = self.config.get("hybrid", {}).get("default_mode", "vector_only")
222
+ if search_alpha is None:
223
+ search_alpha = self.config.get("hybrid", {}).get("default_alpha", 0.5)
224
+
225
+ # Get vector store
226
+ vectorstore = self.vectorstore_manager.get_vectorstore()
227
+ if not vectorstore:
228
+ raise ValueError(
229
+ "Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
230
+ )
231
+
232
+ # Retrieve context with scores for test retrieval
233
+ context_docs_with_scores = self.context_retriever.retrieve_with_scores(
234
+ vectorstore=vectorstore,
235
+ query=query,
236
+ reports=reports,
237
+ sources=sources,
238
+ subtype=subtype,
239
+ k=k,
240
+ search_mode=search_mode,
241
+ alpha=search_alpha,
242
+ )
243
+
244
+ # Extract documents and scores
245
+ context_docs = [doc for doc, score in context_docs_with_scores]
246
+ context_scores = [score for doc, score in context_docs_with_scores]
247
+
248
+ execution_time = time.time() - start_time
249
+
250
+ # Format results with actual scores
251
+ results = []
252
+ for i, (doc, score) in enumerate(zip(context_docs, context_scores)):
253
+ results.append({
254
+ "rank": i + 1,
255
+ "content": doc.page_content, # Return full content without truncation
256
+ "metadata": doc.metadata,
257
+ "score": score if score is not None else 0.0
258
+ })
259
+
260
+ return {
261
+ "results": results,
262
+ "num_results": len(results),
263
+ "execution_time": execution_time,
264
+ "search_mode": search_mode,
265
+ "search_alpha": search_alpha,
266
+ "query": query
267
+ }
268
+
269
+ except Exception as e:
270
+ print(f"❌ Error during retrieval test: {e}")
271
+ log_error(e, {"component": "retrieval_test", "query": query})
272
+ return {
273
+ "results": [],
274
+ "num_results": 0,
275
+ "execution_time": time.time() - start_time,
276
+ "error": str(e),
277
+ "search_mode": search_mode or "unknown",
278
+ "search_alpha": search_alpha or 0.5,
279
+ "query": query
280
+ }
281
+
282
+ def connect_vectorstore(self, force_recreate: bool = False) -> bool:
283
+ """
284
+ Connect to existing vector store.
285
+
286
+ Args:
287
+ force_recreate: If True, recreate the collection if dimension mismatch occurs
288
+
289
+ Returns:
290
+ True if successful, False otherwise
291
+ """
292
+ try:
293
+ vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
294
+ if vectorstore:
295
+ print("✅ Connected to vector store")
296
+ return True
297
+ else:
298
+ print("❌ Failed to connect to vector store")
299
+ return False
300
+ except Exception as e:
301
+ print(f"❌ Error connecting to vector store: {e}")
302
+ log_error(e, {"component": "vectorstore_connection"})
303
+
304
+ # If it's a dimension mismatch error, try with force_recreate
305
+ if "dimensions" in str(e).lower() and not force_recreate:
306
+ print("🔄 Dimension mismatch detected, attempting to recreate collection...")
307
+ try:
308
+ vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=True)
309
+ if vectorstore:
310
+ print("✅ Connected to vector store (recreated)")
311
+ return True
312
+ except Exception as recreate_error:
313
+ print(f"❌ Failed to recreate vector store: {recreate_error}")
314
+ log_error(recreate_error, {"component": "vectorstore_recreation"})
315
+
316
+ return False
317
+
318
+ def create_vectorstore(self) -> bool:
319
+ """
320
+ Create new vector store from chunks.
321
+
322
+ Returns:
323
+ True if successful, False otherwise
324
+ """
325
+ try:
326
+ if not self.chunks:
327
+ raise ValueError("No chunks available for vector store creation")
328
+
329
+ documents = chunks_to_documents(self.chunks)
330
+ self.vectorstore_manager.create_from_documents(documents)
331
+ print("✅ Vector store created successfully")
332
+ return True
333
+ except Exception as e:
334
+ print(f"❌ Error creating vector store: {e}")
335
+ log_error(e, {"component": "vectorstore_creation"})
336
+ return False
337
+
338
+ def create_audit_prompt(self, query: str, context_docs: List[Document]) -> str:
339
+ """Create a prompt for the LLM to generate an answer."""
340
+ try:
341
+ # Ensure query is not None
342
+ if not query or not isinstance(query, str) or query.strip() == "":
343
+ return "Error: No query provided"
344
+
345
+ # Ensure context_docs is not None and is a list
346
+ if context_docs is None:
347
+ context_docs = []
348
+
349
+ # Filter out None documents and ensure they have content
350
+ valid_docs = []
351
+ for doc in context_docs:
352
+ if doc is not None:
353
+ if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
354
+ valid_docs.append(doc)
355
+ elif isinstance(doc, str) and doc.strip():
356
+ valid_docs.append(doc)
357
+
358
+ # Create context string
359
+ if valid_docs:
360
+ context_parts = []
361
+ for i, doc in enumerate(valid_docs, 1):
362
+ if hasattr(doc, 'page_content') and doc.page_content:
363
+ context_parts.append(f"Doc {i}: {doc.page_content}")
364
+ elif isinstance(doc, str) and doc.strip():
365
+ context_parts.append(f"Doc {i}: {doc}")
366
+
367
+ context_string = "\n\n".join(context_parts)
368
+ else:
369
+ context_string = "No relevant context found."
370
+
371
+ # Create the prompt
372
+ prompt = f"""
373
+ {self.system_prompt}
374
+
375
+ Context:
376
+ {context_string}
377
+
378
+ Query: {query}
379
+
380
+ Answer:"""
381
+
382
+ return prompt
383
+
384
+ except Exception as e:
385
+ print(f"Error creating audit prompt: {e}")
386
+ return f"Error creating prompt: {e}"
387
+
388
+ def _generate_answer(self, prompt: str) -> str:
389
+ """Generate answer using the LLM."""
390
+ try:
391
+ if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
392
+ return "Error: No prompt provided"
393
+
394
+ # Ensure LLM client is available
395
+ if not self.llm_client:
396
+ return "Error: LLM client not available"
397
+
398
+ # Generate response using the correct method
399
+ if hasattr(self.llm_client, 'generate'):
400
+ # Use the generate method (for adapters)
401
+ response = self.llm_client.generate([{"role": "user", "content": prompt}])
402
+
403
+ # Extract content from LLMResponse
404
+ if hasattr(response, 'content'):
405
+ answer = response.content
406
+ else:
407
+ answer = str(response)
408
+
409
+ elif hasattr(self.llm_client, 'invoke'):
410
+ # Use the invoke method (for direct LangChain models)
411
+ response = self.llm_client.invoke(prompt)
412
+
413
+ # Extract content safely
414
+ if hasattr(response, 'content') and response.content is not None:
415
+ answer = response.content
416
+ elif isinstance(response, str) and response.strip():
417
+ answer = response
418
+ else:
419
+ answer = str(response) if response is not None else "Error: LLM returned None response"
420
+ else:
421
+ return "Error: LLM client has no generate or invoke method"
422
+
423
+ # Ensure answer is not None and is a string
424
+ if answer is None or not isinstance(answer, str):
425
+ return "Error: LLM returned invalid response"
426
+
427
+ return answer.strip()
428
+
429
+ except Exception as e:
430
+ print(f"Error generating answer: {e}")
431
+ return f"Error generating answer: {e}"
432
+
433
+ def run(
434
+ self,
435
+ query: str,
436
+ reports: List[str] = None,
437
+ sources: List[str] = None,
438
+ subtype: List[str] = None,
439
+ llm_provider: str = None,
440
+ use_reranking: bool = True,
441
+ search_mode: str = None,
442
+ search_alpha: float = None,
443
+ auto_infer_filters: bool = True,
444
+ filters: Dict[str, Any] = None,
445
+ ) -> PipelineResult:
446
+ """
447
+ Run the complete RAG pipeline.
448
+
449
+ Args:
450
+ query: User query
451
+ reports: List of specific report filenames
452
+ sources: Source category filter
453
+ subtype: List of subtypes/filenames
454
+ llm_provider: LLM provider to use
455
+ use_reranking: Whether to use reranking
456
+ search_mode: Search mode (vector, sparse, hybrid)
457
+ search_alpha: Alpha value for hybrid search
458
+ auto_infer_filters: Whether to auto-infer filters from query
+ filters: Optional dict of explicit filters (reports, sources, subtype, year, district, filenames); values here take precedence over the individual arguments
459
+
460
+ Returns:
461
+ PipelineResult object
462
+ """
463
+ try:
464
+ # Validate input
465
+ if not query or not isinstance(query, str) or query.strip() == "":
466
+ return PipelineResult(
467
+ answer="Error: Invalid query provided",
468
+ sources=[],
469
+ execution_time=0.0,
470
+ metadata={'error': 'Invalid query'},
471
+ query=query
472
+ )
473
+
474
+ # Ensure lists are not None
475
+ if reports is None:
476
+ reports = []
477
+ if subtype is None:
478
+ subtype = []
479
+
480
+ start_time = time.time()
481
+
482
+ # Auto-infer filters if enabled and no explicit filters provided
483
+ inferred_filters = {}
484
+ filters_applied = False
485
+ qdrant_filter = None  # Qdrant filter built by auto-inference (if any)
+ filter_summary = None  # Human-readable summary of the inferred filter
486
+
487
+ if auto_infer_filters and not any([reports, sources, subtype]):
488
+ print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
489
+ try:
490
+ # Import get_available_metadata here to avoid circular imports
491
+ from auditqa.retrieval.filter import get_available_metadata, infer_filters_from_query
492
+
493
+ # Get available metadata
494
+ available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
495
+
496
+ # Infer filters from query - this returns a Qdrant filter
497
+ qdrant_filter, filter_summary = infer_filters_from_query(
498
+ query=query,
499
+ available_metadata=available_metadata,
500
+ llm_client=self.llm_client
501
+ )
502
+
503
+ if qdrant_filter:
504
+ print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
505
+ filters_applied = True
506
+ # Don't set sources/reports/subtype - use the Qdrant filter directly
507
+ else:
508
+ print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")
509
+
510
+ except Exception as e:
511
+ print(f"❌ AUTO-INFERENCE FAILED: {e}")
512
+ qdrant_filter = None
513
+ else:
514
+ # Check if any explicit filters were provided
515
+ filters_applied = any([reports, sources, subtype])
516
+ if filters_applied:
517
+ print(f"✅ EXPLICIT FILTERS: Using provided filters")
518
+ else:
519
+ print(f"⚠️ NO FILTERS: No explicit filters and auto-inference disabled")
520
+
521
+ # Merge the optional filters dict with the explicit arguments (dict values take precedence)
522
+ filters = filters or {}
+ reports = filters.get('reports', reports) or []
523
+ sources = filters.get('sources', sources) or []
524
+ subtype = filters.get('subtype', subtype) or []
525
+ year = filters.get('year', [])
526
+ district = filters.get('district', [])
527
+ filenames = filters.get('filenames', [])  # Support mutually exclusive filename filtering
528
+
529
+ # Get vectorstore
530
+ vectorstore = self.vectorstore_manager.get_vectorstore()
531
+ if not vectorstore:
532
+ return PipelineResult(
533
+ answer="Error: Vector store not available",
534
+ sources=[],
535
+ execution_time=0.0,
536
+ metadata={'error': 'Vector store not available'},
537
+ query=query
538
+ )
539
+
540
+ # Initialize context retriever if not already done
541
+ if not hasattr(self, 'context_retriever') or self.context_retriever is None:
542
+ # Get the actual vectorstore object
543
+ vectorstore_obj = self.vectorstore_manager.get_vectorstore()
544
+ if vectorstore_obj is None:
545
+ print("❌ ERROR: Vectorstore is None, cannot initialize ContextRetriever")
546
+ return PipelineResult(
+ answer="Error: Vector store not available",
+ sources=[],
+ execution_time=0.0,
+ metadata={'error': 'Vector store not available'},
+ query=query
+ )
547
+ self.context_retriever = ContextRetriever(vectorstore_obj, self.config)
548
+ print("✅ ContextRetriever initialized successfully")
549
+
550
+ # Debug config access
551
+ print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
552
+ print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
553
+ print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
554
+ print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")
555
+
556
+ # Get the correct top_k value
557
+ # Priority: retrieval config > retriever config > default
558
+ top_k = (
559
+ self.config.get('retrieval', {}).get('top_k') or
560
+ self.config.get('retriever', {}).get('top_k') or
561
+ 5
562
+ )
563
+
564
+ # Get reranking setting (config value overrides the call argument when present)
565
+ use_reranking = self.config.get('retrieval', {}).get('use_reranking', use_reranking)
566
+
567
+ print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
568
+ print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")
569
+
570
+ # Retrieve context using the context retriever
571
+ context_docs = self.context_retriever.retrieve_context(
572
+ query=query,
573
+ k=top_k,
574
+ reports=reports,
575
+ sources=sources,
576
+ subtype=subtype,
577
+ year=year,
578
+ district=district,
579
+ filenames=filenames,
580
+ use_reranking=use_reranking,
581
+ qdrant_filter=qdrant_filter
582
+ )
583
+
584
+ # Ensure context_docs is not None
585
+ if context_docs is None:
586
+ context_docs = []
587
+
588
+ # Generate answer
589
+ answer = self._generate_answer(self.create_audit_prompt(query, context_docs))
590
+
591
+ execution_time = time.time() - start_time
592
+
593
+ # Create result with comprehensive metadata
594
+ result = PipelineResult(
595
+ answer=answer,
596
+ sources=context_docs,
597
+ execution_time=execution_time,
598
+ metadata={
599
+ 'llm_provider': llm_provider,
600
+ 'use_reranking': use_reranking,
601
+ 'search_mode': search_mode,
602
+ 'search_alpha': search_alpha,
603
+ 'auto_infer_filters': auto_infer_filters,
604
+ 'filters_applied': filters_applied,
605
+ 'with_filtering': filters_applied,
606
+ 'filter_conditions': {
607
+ 'reports': reports,
608
+ 'sources': sources,
609
+ 'subtype': subtype
610
+ },
611
+ 'inferred_filters': inferred_filters,
612
+ 'applied_filters': {
613
+ 'reports': reports,
614
+ 'sources': sources,
615
+ 'subtype': subtype
616
+ },
617
+ # Store filter and reranking metadata
618
+ 'filter_details': {
619
+ 'explicit_filters': {
620
+ 'reports': reports,
621
+ 'sources': sources,
622
+ 'subtype': subtype,
623
+ 'year': year
624
+ },
625
+ 'inferred_filters': inferred_filters if auto_infer_filters else {},
626
+ 'auto_inference_enabled': auto_infer_filters,
627
+ 'qdrant_filter_applied': qdrant_filter is not None,
628
+ 'filter_summary': filter_summary
629
+ },
630
+ 'reranker_model': self._get_reranker_model_name() if use_reranking else None,
631
+ 'reranker_applied': use_reranking,
632
+ 'reranking_info': {
633
+ 'model': self._get_reranker_model_name(),
634
+ 'applied': use_reranking,
635
+ 'top_k': len(context_docs) if context_docs else 0,
636
+ # 'original_documents': [
637
+ # {
638
+ # 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
639
+ # 'metadata': doc.metadata,
640
+ # 'score': getattr(doc, 'score', getattr(doc, 'original_score', 0.0))
641
+ # } for doc in context_docs
642
+ # ] if use_reranking else None,
643
+ 'reranked_documents': [
644
+ {
645
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
646
+ 'metadata': doc.metadata,
647
+ 'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
648
+ 'original_rank': doc.metadata.get('original_rank', None),
649
+ 'final_rank': doc.metadata.get('final_rank', None),
650
+ 'reranked_score': doc.metadata.get('reranked_score', None)
651
+ } for doc in context_docs
652
+ ] if use_reranking else None
653
+ }
654
+ },
655
+ query=query
656
+ )
657
+
658
+ return result
659
+
660
+ except Exception as e:
661
+ print(f"Error in pipeline run: {e}")
662
+ return PipelineResult(
663
+ answer=f"Error processing query: {e}",
664
+ sources=[],
665
+ execution_time=0.0,
666
+ metadata={'error': str(e)},
667
+ query=query
668
+ )
669
+
670
+
671
+
672
+ def get_system_status(self) -> Dict[str, Any]:
673
+ """
674
+ Get system status information.
675
+
676
+ Returns:
677
+ Dictionary with system status
678
+ """
679
+ status = {
680
+ "config_loaded": bool(self.config),
681
+ "chunks_loaded": bool(self.chunks),
682
+ "vectorstore_connected": bool(
683
+ self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
684
+ ),
685
+ "components_initialized": bool(
686
+ getattr(self, "context_retriever", None) and getattr(self, "report_service", None)
687
+ ),
688
+ }
689
+
690
+ if self.chunks:
691
+ status["num_chunks"] = len(self.chunks)
692
+
693
+ if self.report_service:
694
+ status["available_sources"] = self.report_service.get_available_sources()
695
+ status["available_reports"] = len(
696
+ self.report_service.get_available_reports()
697
+ )
698
+
699
+ status["overall_status"] = (
700
+ "ready"
701
+ if all(
702
+ [
703
+ status["config_loaded"],
704
+ status["chunks_loaded"],
705
+ status["vectorstore_connected"],
706
+ status["components_initialized"],
707
+ ]
708
+ )
709
+ else "not_ready"
710
+ )
711
+
712
+ return status
713
+
714
+ def get_available_llm_providers(self) -> List[str]:
715
+ """Get list of available LLM providers."""
716
+ providers = []
717
+ reader_config = self.config.get("reader", {})
718
+
719
+ for provider in [
720
+ "MISTRAL",
721
+ "OPENAI",
722
+ "OLLAMA",
723
+ "INF_PROVIDERS",
724
+ "NVIDIA",
725
+ "DEDICATED",
726
+ "OPENROUTER",
727
+ ]:
728
+ if provider in reader_config:
729
+ providers.append(provider.lower())
730
+
731
+ return providers
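For orientation, a minimal usage sketch of the pipeline methods added above. The PipelineManager constructor signature is not shown in this commit, so passing the loaded config is an assumption; the query string is illustrative.

from src import PipelineManager, load_config

config = load_config()              # defaults to src/config/settings.yaml
pipeline = PipelineManager(config)  # assumed constructor: takes the loaded config dict

if pipeline.connect_vectorstore():  # connect to the existing Qdrant collection
    result = pipeline.run(
        query="What were the main audit findings for local governments?",
        use_reranking=True,
        auto_infer_filters=True,
    )
    print(result.answer)
    print(f"{len(result.sources)} documents in {result.execution_time:.2f}s")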
src/reporting/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Report metadata and utilities."""
2
+
3
+ from .metadata import get_report_metadata, get_available_sources
4
+ from .service import ReportService
5
+
6
+ __all__ = ["get_report_metadata", "get_available_sources", "ReportService"]
src/reporting/feedback_schema.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Feedback Schema for RAG Chatbot
3
+
4
+ This module defines dataclasses for feedback data structures
5
+ and provides Snowflake schema generation.
6
+ """
7
+
8
+ from dataclasses import dataclass, asdict, field
9
+ from typing import List, Optional, Dict, Any, Union
10
+ from datetime import datetime
11
+
12
+
13
+ @dataclass
14
+ class RetrievedDocument:
15
+ """Single retrieved document metadata"""
16
+ doc_id: str
17
+ filename: str
18
+ page: int
19
+ score: float
20
+ content: str
21
+ metadata: Dict[str, Any]
22
+
23
+
24
+ @dataclass
25
+ class RetrievalEntry:
26
+ """Single retrieval operation metadata"""
27
+ rag_query: str
28
+ documents_retrieved: List[RetrievedDocument]
29
+ conversation_length: int
30
+ filters_applied: Optional[Dict[str, Any]] = None
31
+ timestamp: Optional[float] = None
32
+ _raw_data: Optional[Dict[str, Any]] = None
33
+
34
+
35
+ @dataclass
36
+ class UserFeedback:
37
+ """User feedback submission data"""
38
+ feedback_id: str
39
+ open_ended_feedback: Optional[str]
40
+ score: int
41
+ is_feedback_about_last_retrieval: bool
42
+ retrieved_data: List[RetrievalEntry]
43
+ conversation_id: str
44
+ timestamp: float
45
+ message_count: int
46
+ has_retrievals: bool
47
+ retrieval_count: int
48
+ user_query: Optional[str] = None
49
+ bot_response: Optional[str] = None
50
+ created_at: str = field(default_factory=lambda: datetime.now().isoformat())
51
+
52
+ def to_dict(self) -> Dict[str, Any]:
53
+ """Convert to dictionary with nested data structures"""
54
+ result = asdict(self)
55
+ # Handle nested objects
56
+ if self.retrieved_data:
57
+ result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
58
+ return result
59
+
60
+ def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
61
+ """Serialize retrieval entry to dict"""
62
+ # If raw data exists, use it (it's already properly formatted)
63
+ if hasattr(entry, '_raw_data') and entry._raw_data:
64
+ return entry._raw_data
65
+
66
+ # Otherwise, serialize the dataclass
67
+ result = asdict(entry)
68
+ if entry.documents_retrieved:
69
+ result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
70
+ return result
71
+
72
+ @staticmethod
+ def to_snowflake_schema() -> Dict[str, Any]:
73
+ """Generate Snowflake schema for this dataclass"""
74
+ schema = {
75
+ "feedback_id": "VARCHAR(255)",
76
+ "open_ended_feedback": "VARCHAR(16777216)", # Large text
77
+ "score": "INTEGER",
78
+ "is_feedback_about_last_retrieval": "BOOLEAN",
79
+ "conversation_id": "VARCHAR(255)",
80
+ "timestamp": "NUMBER(20, 0)",
81
+ "message_count": "INTEGER",
82
+ "has_retrievals": "BOOLEAN",
83
+ "retrieval_count": "INTEGER",
84
+ "user_query": "VARCHAR(16777216)",
85
+ "bot_response": "VARCHAR(16777216)",
86
+ "created_at": "TIMESTAMP_NTZ",
87
+ "retrieved_data": "VARIANT", # Array of retrieval entries
88
+ # retrieved_data structure:
89
+ # [
90
+ # {
91
+ # "rag_query": "...",
92
+ # "conversation_length": 5,
93
+ # "timestamp": 1234567890,
94
+ # "docs_retrieved": [
95
+ # {"filename": "...", "page": 14, "score": 0.95, ...},
96
+ # ...
97
+ # ]
98
+ # },
99
+ # ...
100
+ # ]
101
+ }
102
+ return schema
103
+
104
+ @classmethod
105
+ def get_snowflake_create_table_sql(cls, table_name: str = "user_feedback") -> str:
106
+ """Generate CREATE TABLE SQL for Snowflake"""
107
+ schema = cls.to_snowflake_schema()
108
+
109
+ columns = []
110
+ for col_name, col_type in schema.items():
111
+ nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
112
+ columns.append(f" {col_name} {col_type} {nullable}")
113
+
114
+ # Build SQL string properly
115
+ columns_str = ",\n".join(columns)
116
+
117
+ sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
118
+ {columns_str},
119
+ PRIMARY KEY (feedback_id)
120
+ );
121
+
122
+ -- Note: Snowflake standard tables do not support CREATE INDEX;
123
+ -- query pruning is handled automatically via micro-partitions.
124
+ -- For large tables, consider: ALTER TABLE {table_name} CLUSTER BY (timestamp);
125
+ """
131
+ return sql
132
+
133
+
134
+ # Snowflake variant schema for retrieved_data array
135
+ RETRIEVAL_ENTRY_SCHEMA = {
136
+ "rag_query": "VARCHAR",
137
+ "documents_retrieved": "ARRAY", # Array of document objects
138
+ "conversation_length": "INTEGER",
139
+ "filters_applied": "OBJECT",
140
+ "timestamp": "NUMBER"
141
+ }
142
+
143
+ DOCUMENT_SCHEMA = {
144
+ "doc_id": "VARCHAR",
145
+ "filename": "VARCHAR",
146
+ "page": "INTEGER",
147
+ "score": "DOUBLE",
148
+ "content": "VARCHAR(16777216)",
149
+ "metadata": "OBJECT"
150
+ }
151
+
152
+
153
+ def generate_snowflake_schema_sql() -> str:
154
+ """Generate complete Snowflake schema SQL for feedback system"""
155
+ return UserFeedback.get_snowflake_create_table_sql("user_feedback")
156
+
157
+
158
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
159
+ """Create UserFeedback instance from dictionary"""
160
+ # Parse retrieved_data if present
161
+ retrieved_data = []
162
+ if "retrieved_data" in data and data["retrieved_data"]:
163
+ for entry_dict in data.get("retrieved_data", []):
164
+ # Map the actual structure from rag_retrieval_history
165
+ # Entry has: conversation_up_to, rag_query_expansion, docs_retrieved
166
+ try:
167
+ # Try to map to expected structure
168
+ entry = RetrievalEntry(
169
+ rag_query=entry_dict.get("rag_query_expansion", ""),
170
+ documents_retrieved=[], # Empty for now, will store as raw data
171
+ conversation_length=len(entry_dict.get("conversation_up_to", [])),
172
+ filters_applied=None,
173
+ timestamp=entry_dict.get("timestamp", None)
174
+ )
175
+ # Store raw data in the entry
176
+ entry._raw_data = entry_dict # Store original for preservation
177
+ retrieved_data.append(entry)
178
+ except Exception:
179
+ # If mapping fails, skip this entry rather than failing the whole feedback
180
+ continue
181
+
182
+ return UserFeedback(
183
+ feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
184
+ open_ended_feedback=data.get("open_ended_feedback"),
185
+ score=data["score"],
186
+ is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
187
+ retrieved_data=retrieved_data,
188
+ conversation_id=data["conversation_id"],
189
+ timestamp=data["timestamp"],
190
+ message_count=data["message_count"],
191
+ has_retrievals=data["has_retrievals"],
192
+ retrieval_count=data["retrieval_count"],
193
+ user_query=data.get("user_query"),
194
+ bot_response=data.get("bot_response")
195
+ )
196
+
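To make the dataclass flow above concrete, a short sketch of building a UserFeedback from a raw dict and emitting the table DDL (the field values are illustrative):

import time
from src.reporting.feedback_schema import (
    create_feedback_from_dict,
    generate_snowflake_schema_sql,
)

feedback = create_feedback_from_dict({
    "score": 4,
    "is_feedback_about_last_retrieval": True,
    "conversation_id": "conv-123",  # illustrative ID
    "timestamp": time.time(),
    "message_count": 6,
    "has_retrievals": True,
    "retrieval_count": 2,
    "open_ended_feedback": "Answer cited the right report.",
    "retrieved_data": [],
})
print(feedback.feedback_id)             # auto-derived from the timestamp when absent
print(generate_snowflake_schema_sql())  # CREATE TABLE ... for the user_feedback table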
src/reporting/metadata.py ADDED
@@ -0,0 +1,216 @@
1
+ """Report metadata management."""
2
+
3
+ from typing import Dict, List, Any, Set
4
+ from pathlib import Path
5
+
6
+
7
+ def get_report_metadata(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
8
+ """
9
+ Extract metadata from chunks.
10
+
11
+ Args:
12
+ chunks: List of chunk dictionaries
13
+
14
+ Returns:
15
+ Dictionary with report metadata
16
+ """
17
+ if not chunks:
18
+ return {}
19
+
20
+ sources = set()
21
+ filenames = set()
22
+ years = set()
23
+
24
+ for chunk in chunks:
25
+ metadata = chunk.get("metadata", {})
26
+
27
+ if "source" in metadata:
28
+ sources.add(metadata["source"])
29
+
30
+ if "filename" in metadata:
31
+ filenames.add(metadata["filename"])
32
+
33
+ if "year" in metadata:
34
+ years.add(metadata["year"])
35
+
36
+ return {
37
+ "sources": sorted(list(sources)),
38
+ "filenames": sorted(list(filenames)),
39
+ "years": sorted(list(years)),
40
+ "total_chunks": len(chunks)
41
+ }
42
+
43
+
44
+ def get_available_sources() -> List[str]:
45
+ """
46
+ Get list of available report sources (legacy compatibility).
47
+
48
+ Returns:
49
+ List of source categories
50
+ """
51
+ # This would typically come from the original auditqa_old.reports module
52
+ # For now, return common categories
53
+ return [
54
+ "Consolidated",
55
+ "Ministry, Department, Agency and Projects",
56
+ "Local Government",
57
+ "Value for Money",
58
+ "Thematic",
59
+ "Hospital",
60
+ "Project"
61
+ ]
62
+
63
+
64
+ def get_source_subtypes() -> Dict[str, List[str]]:
65
+ """
66
+ Get mapping of sources to their subtypes (placeholder).
67
+
68
+ Returns:
69
+ Dictionary mapping sources to subtypes
70
+ """
71
+ # This was originally imported from auditqa_old.reports.new_files
72
+ # For now, return a placeholder structure
73
+ return {
74
+ "Consolidated": ["Annual Consolidated OAG 2024", "Annual Consolidated OAG 2023"],
75
+ "Local Government": ["District Reports", "Municipal Reports"],
76
+ "Ministry, Department, Agency and Projects": ["Ministry Reports", "Agency Reports"],
77
+ "Value for Money": ["VFM Reports 2024", "VFM Reports 2023"],
78
+ "Thematic": ["Thematic Reports 2024", "Thematic Reports 2023"],
79
+ "Hospital": ["Hospital Reports 2024", "Hospital Reports 2023"],
80
+ "Project": ["Project Reports 2024", "Project Reports 2023"]
81
+ }
82
+
83
+
84
+ def validate_report_filters(
85
+ reports: List[str] = None,
86
+ sources: str = None,
87
+ subtype: List[str] = None,
88
+ available_metadata: Dict[str, Any] = None
89
+ ) -> Dict[str, Any]:
90
+ """
91
+ Validate report filter parameters.
92
+
93
+ Args:
94
+ reports: List of specific report filenames
95
+ sources: Source category
96
+ subtype: List of subtypes
97
+ available_metadata: Available metadata for validation
98
+
99
+ Returns:
100
+ Dictionary with validation results
101
+ """
102
+ validation_result = {
103
+ "valid": True,
104
+ "warnings": [],
105
+ "errors": []
106
+ }
107
+
108
+ if not available_metadata:
109
+ validation_result["warnings"].append("No metadata available for validation")
110
+ return validation_result
111
+
112
+ available_sources = available_metadata.get("sources", [])
113
+ available_filenames = available_metadata.get("filenames", [])
114
+
115
+ # Validate sources
116
+ if sources and sources not in available_sources:
117
+ validation_result["errors"].append(f"Source '{sources}' not found in available sources")
118
+ validation_result["valid"] = False
119
+
120
+ # Validate reports
121
+ if reports:
122
+ for report in reports:
123
+ if report not in available_filenames:
124
+ validation_result["warnings"].append(f"Report '{report}' not found in available reports")
125
+
126
+ # Validate subtypes
127
+ if subtype:
128
+ for sub in subtype:
129
+ if sub not in available_filenames:
130
+ validation_result["warnings"].append(f"Subtype '{sub}' not found in available reports")
131
+
132
+ return validation_result
133
+
134
+
135
+ def get_report_statistics(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
136
+ """
137
+ Get statistics about reports in chunks.
138
+
139
+ Args:
140
+ chunks: List of chunk dictionaries
141
+
142
+ Returns:
143
+ Dictionary with report statistics
144
+ """
145
+ if not chunks:
146
+ return {}
147
+
148
+ stats = {
149
+ "total_chunks": len(chunks),
150
+ "sources": {},
151
+ "years": {},
152
+ "avg_chunk_length": 0,
153
+ "total_content_length": 0
154
+ }
155
+
156
+ total_length = 0
157
+
158
+ for chunk in chunks:
159
+ content = chunk.get("content", "")
160
+ total_length += len(content)
161
+
162
+ metadata = chunk.get("metadata", {})
163
+
164
+ # Count by source
165
+ source = metadata.get("source", "Unknown")
166
+ stats["sources"][source] = stats["sources"].get(source, 0) + 1
167
+
168
+ # Count by year
169
+ year = metadata.get("year", "Unknown")
170
+ stats["years"][year] = stats["years"].get(year, 0) + 1
171
+
172
+ stats["total_content_length"] = total_length
173
+ stats["avg_chunk_length"] = total_length / len(chunks) if chunks else 0
174
+
175
+ return stats
176
+
177
+
178
+ def filter_chunks_by_metadata(
179
+ chunks: List[Dict[str, Any]],
180
+ source_filter: str = None,
181
+ filename_filter: List[str] = None,
182
+ year_filter: List[str] = None
183
+ ) -> List[Dict[str, Any]]:
184
+ """
185
+ Filter chunks by metadata criteria.
186
+
187
+ Args:
188
+ chunks: List of chunk dictionaries
189
+ source_filter: Source to filter by
190
+ filename_filter: List of filenames to filter by
191
+ year_filter: List of years to filter by
192
+
193
+ Returns:
194
+ Filtered list of chunks
195
+ """
196
+ filtered_chunks = chunks
197
+
198
+ if source_filter:
199
+ filtered_chunks = [
200
+ chunk for chunk in filtered_chunks
201
+ if chunk.get("metadata", {}).get("source") == source_filter
202
+ ]
203
+
204
+ if filename_filter:
205
+ filtered_chunks = [
206
+ chunk for chunk in filtered_chunks
207
+ if chunk.get("metadata", {}).get("filename") in filename_filter
208
+ ]
209
+
210
+ if year_filter:
211
+ filtered_chunks = [
212
+ chunk for chunk in filtered_chunks
213
+ if chunk.get("metadata", {}).get("year") in year_filter
214
+ ]
215
+
216
+ return filtered_chunks
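The metadata helpers above compose naturally: filter first, then summarize. A minimal sketch with an illustrative chunk structure (a content string plus a metadata dict, as the functions expect):

from src.reporting.metadata import filter_chunks_by_metadata, get_report_statistics

chunks = [  # illustrative chunks
    {"content": "Finding 1 ...",
     "metadata": {"source": "Local Government", "filename": "gulu_2023.pdf", "year": "2023"}},
    {"content": "Finding 2 ...",
     "metadata": {"source": "Hospital", "filename": "mulago_2022.pdf", "year": "2022"}},
]

lg_chunks = filter_chunks_by_metadata(chunks, source_filter="Local Government")
stats = get_report_statistics(lg_chunks)
print(stats["total_chunks"], stats["sources"], stats["avg_chunk_length"])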
src/reporting/service.py ADDED
@@ -0,0 +1,144 @@
1
+ """Report service for managing report operations."""
2
+
3
+ from typing import Dict, List, Any, Optional
4
+ from .metadata import get_report_metadata, get_available_sources, get_source_subtypes
5
+
6
+
7
+ class ReportService:
8
+ """Service class for report operations."""
9
+
10
+ def __init__(self, chunks: List[Dict[str, Any]] = None):
11
+ """
12
+ Initialize report service.
13
+
14
+ Args:
15
+ chunks: List of chunk dictionaries
16
+ """
17
+ self.chunks = chunks or []
18
+ self.metadata = get_report_metadata(self.chunks) if self.chunks else {}
19
+
20
+ def get_available_sources(self) -> List[str]:
21
+ """Get available report sources."""
22
+ if self.metadata:
23
+ return self.metadata.get("sources", [])
24
+ return get_available_sources()
25
+
26
+ def get_available_reports(self) -> List[str]:
27
+ """Get available report filenames."""
28
+ return self.metadata.get("filenames", [])
29
+
30
+ def get_source_subtypes(self) -> Dict[str, List[str]]:
31
+ """Get source to subtype mapping."""
32
+ # For now, use the placeholder function
33
+ # In a full implementation, this would be derived from actual data
34
+ return get_source_subtypes()
35
+
36
+ def get_reports_by_source(self, source: str) -> List[str]:
37
+ """
38
+ Get reports filtered by source.
39
+
40
+ Args:
41
+ source: Source category
42
+
43
+ Returns:
44
+ List of report filenames
45
+ """
46
+ if not self.chunks:
47
+ return []
48
+
49
+ reports = set()
50
+ for chunk in self.chunks:
51
+ metadata = chunk.get("metadata", {})
52
+ if metadata.get("source") == source:
53
+ filename = metadata.get("filename")
54
+ if filename:
55
+ reports.add(filename)
56
+
57
+ return sorted(list(reports))
58
+
59
+ def get_years_by_source(self, source: str) -> List[str]:
60
+ """
61
+ Get years available for a specific source.
62
+
63
+ Args:
64
+ source: Source category
65
+
66
+ Returns:
67
+ List of years
68
+ """
69
+ if not self.chunks:
70
+ return []
71
+
72
+ years = set()
73
+ for chunk in self.chunks:
74
+ metadata = chunk.get("metadata", {})
75
+ if metadata.get("source") == source:
76
+ year = metadata.get("year")
77
+ if year:
78
+ years.add(year)
79
+
80
+ return sorted(list(years))
81
+
82
+ def search_reports(self, query: str) -> List[str]:
83
+ """
84
+ Search for reports by name.
85
+
86
+ Args:
87
+ query: Search query
88
+
89
+ Returns:
90
+ List of matching report filenames
91
+ """
92
+ if not self.chunks:
93
+ return []
94
+
95
+ query_lower = query.lower()
96
+ matching_reports = set()
97
+
98
+ for chunk in self.chunks:
99
+ metadata = chunk.get("metadata", {})
100
+ filename = metadata.get("filename", "")
101
+
102
+ if query_lower in filename.lower():
103
+ matching_reports.add(filename)
104
+
105
+ return sorted(list(matching_reports))
106
+
107
+ def get_report_info(self, filename: str) -> Dict[str, Any]:
108
+ """
109
+ Get information about a specific report.
110
+
111
+ Args:
112
+ filename: Report filename
113
+
114
+ Returns:
115
+ Dictionary with report information
116
+ """
117
+ if not self.chunks:
118
+ return {}
119
+
120
+ report_info = {
121
+ "filename": filename,
122
+ "chunk_count": 0,
123
+ "sources": set(),
124
+ "years": set(),
125
+ "total_content_length": 0
126
+ }
127
+
128
+ for chunk in self.chunks:
129
+ metadata = chunk.get("metadata", {})
130
+ if metadata.get("filename") == filename:
131
+ report_info["chunk_count"] += 1
132
+ report_info["total_content_length"] += len(chunk.get("content", ""))
133
+
134
+ if "source" in metadata:
135
+ report_info["sources"].add(metadata["source"])
136
+
137
+ if "year" in metadata:
138
+ report_info["years"].add(metadata["year"])
139
+
140
+ # Convert sets to lists
141
+ report_info["sources"] = list(report_info["sources"])
142
+ report_info["years"] = list(report_info["years"])
143
+
144
+ return report_info
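ReportService wraps the same chunk structure behind query methods; a minimal sketch (the chunk list and filenames are illustrative):

from src.reporting.service import ReportService

chunks = [{"content": "Finding 1 ...",
           "metadata": {"source": "Local Government", "filename": "gulu_2023.pdf", "year": "2023"}}]
service = ReportService(chunks)
print(service.get_available_sources())            # ['Local Government']
print(service.get_reports_by_source("Local Government"))
print(service.search_reports("gulu"))             # substring match on filenames
info = service.get_report_info("gulu_2023.pdf")
print(info["chunk_count"], info["years"])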
src/reporting/snowflake_connector.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ Snowflake Connector for Feedback System
3
+
4
+ This module handles inserting user feedback into Snowflake.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ from typing import Dict, Any, Optional
11
+ from src.reporting.feedback_schema import UserFeedback
12
+
13
+ # Try to import snowflake connector
14
+ try:
15
+ import snowflake.connector
16
+ SNOWFLAKE_AVAILABLE = True
17
+ except ImportError:
18
+ SNOWFLAKE_AVAILABLE = False
19
+ logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SnowflakeFeedbackConnector:
27
+ """Connector for inserting feedback into Snowflake"""
28
+
29
+ def __init__(
30
+ self,
31
+ user: str,
32
+ password: str,
33
+ account: str,
34
+ warehouse: str,
35
+ database: str = "SNOWFLAKE_LEARNING",
36
+ schema: str = "PUBLIC"
37
+ ):
38
+ self.user = user
39
+ self.password = password
40
+ self.account = account
41
+ self.warehouse = warehouse
42
+ self.database = database
43
+ self.schema = schema
44
+ self._connection = None
45
+
46
+ def connect(self):
47
+ """Establish Snowflake connection"""
48
+ if not SNOWFLAKE_AVAILABLE:
49
+ raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
50
+
51
+ logger.info("=" * 80)
52
+ logger.info("🔌 SNOWFLAKE CONNECTION: Attempting to connect...")
53
+ logger.info(f" - Account: {self.account}")
54
+ logger.info(f" - Warehouse: {self.warehouse}")
55
+ logger.info(f" - Database: {self.database}")
56
+ logger.info(f" - Schema: {self.schema}")
57
+ logger.info(f" - User: {self.user}")
58
+
59
+ try:
60
+ self._connection = snowflake.connector.connect(
61
+ user=self.user,
62
+ password=self.password,
63
+ account=self.account,
64
+ warehouse=self.warehouse
65
+ # Don't set database/schema in connection - we'll do it per query
66
+ )
67
+ logger.info("✅ SNOWFLAKE CONNECTION: Successfully connected")
68
+ logger.info("=" * 80)
69
+ print(f"✅ Connected to Snowflake: {self.database}.{self.schema}")
70
+ except Exception as e:
71
+ logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
72
+ logger.error("=" * 80)
73
+ print(f"❌ Failed to connect to Snowflake: {e}")
74
+ raise
75
+
76
+ def disconnect(self):
77
+ """Close Snowflake connection"""
78
+ if self._connection:
79
+ self._connection.close()
80
+ print("✅ Disconnected from Snowflake")
81
+
82
+ def insert_feedback(self, feedback: UserFeedback) -> bool:
83
+ """Insert a single feedback record into Snowflake"""
84
+ logger.info("=" * 80)
85
+ logger.info("🔄 SNOWFLAKE INSERT: Starting feedback insertion process")
86
+ logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
87
+
88
+ if not self._connection:
89
+ logger.error("❌ Not connected to Snowflake. Call connect() first.")
90
+ raise RuntimeError("Not connected to Snowflake. Call connect() first.")
91
+
92
+ try:
93
+ logger.info("📊 VALIDATION: Validating feedback data structure...")
94
+
95
+ # Validate feedback object
96
+ validation_errors = []
97
+ if not feedback.feedback_id:
98
+ validation_errors.append("Missing feedback_id")
99
+ if feedback.score is None:
100
+ validation_errors.append("Missing score")
101
+ if feedback.timestamp is None:
102
+ validation_errors.append("Missing timestamp")
103
+
104
+ if validation_errors:
105
+ logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
106
+ return False
107
+ else:
108
+ logger.info("✅ VALIDATION PASSED: All required fields present")
109
+
110
+ logger.info("📋 Data Summary:")
111
+ logger.info(f" - Feedback ID: {feedback.feedback_id}")
112
+ logger.info(f" - Score: {feedback.score}")
113
+ logger.info(f" - Conversation ID: {feedback.conversation_id}")
114
+ logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
115
+ logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
116
+ logger.info(f" - Message Count: {feedback.message_count}")
117
+ logger.info(f" - Timestamp: {feedback.timestamp}")
118
+
119
+ cursor = self._connection.cursor()
120
+ logger.info("✅ SNOWFLAKE CONNECTION: Cursor created")
121
+
122
+ # Set database and schema context
123
+ logger.info(f"🔧 SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
124
+ try:
125
+ cursor.execute(f'USE DATABASE "{self.database}"')
126
+ cursor.execute(f'USE SCHEMA "{self.schema}"')
127
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
128
+ current_db, current_schema = cursor.fetchone()
129
+ logger.info(f"✅ Current context verified: Database={current_db}, Schema={current_schema}")
130
+ except Exception as e:
131
+ logger.error(f"❌ Could not set context: {e}")
132
+ raise
133
+
134
+ # Prepare data
135
+ logger.info("🔧 DATA PREPARATION: Preparing retrieved_data...")
136
+ retrieved_data_raw = feedback.to_dict()['retrieved_data']
137
+
138
+ logger.info(f" - Retrieved data type (raw): {type(retrieved_data_raw).__name__}")
139
+ logger.info(f" - Retrieved data: {repr(retrieved_data_raw)[:200]}")
140
+
141
+ # If retrieved_data is already a string (from UI), parse it
142
+ if isinstance(retrieved_data_raw, str):
143
+ logger.info(" - Parsing string to Python object")
144
+ retrieved_data = json.loads(retrieved_data_raw)
145
+ elif retrieved_data_raw is None:
146
+ retrieved_data = None
147
+ else:
148
+ # It's already a Python object (list/dict)
149
+ logger.info(" - Data is already a Python object")
150
+ retrieved_data = retrieved_data_raw
151
+
152
+ logger.info(f" - Retrieved data size: {len(str(retrieved_data)) if retrieved_data else 0} characters")
153
+ logger.info(f" - Retrieved data type: {type(retrieved_data).__name__}")
154
+
155
+ # Convert to JSON string for TEXT column
156
+ if retrieved_data:
157
+ retrieved_data_for_db = json.dumps(retrieved_data)
158
+ logger.info(f" - Converting to JSON string for TEXT column")
159
+ logger.info(f" - JSON string length: {len(retrieved_data_for_db)}")
160
+ else:
161
+ logger.info(f" - Retrieved data is None, using NULL")
162
+ retrieved_data_for_db = None
163
+
164
+ # Build SQL with retrieved_data as a TEXT column parameter
165
+ sql = f"""INSERT INTO user_feedback (
166
+ feedback_id,
167
+ open_ended_feedback,
168
+ score,
169
+ is_feedback_about_last_retrieval,
170
+ conversation_id,
171
+ timestamp,
172
+ message_count,
173
+ has_retrievals,
174
+ retrieval_count,
175
+ user_query,
176
+ bot_response,
177
+ created_at,
178
+ retrieved_data
179
+ ) VALUES (
180
+ %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
181
+ %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
182
+ %(retrieval_count)s, %(user_query)s, %(bot_response)s, %(created_at)s,
183
+ %(retrieved_data)s
184
+ )"""
185
+
186
+ logger.info("📝 SQL PREPARATION: Building INSERT statement...")
187
+ logger.info(f" - Target table: user_feedback")
188
+ logger.info(f" - Database: {self.database}")
189
+ logger.info(f" - Schema: {self.schema}")
190
+
191
+ # Prepare parameters
192
+ params = {
193
+ 'feedback_id': feedback.feedback_id,
194
+ 'open_ended_feedback': feedback.open_ended_feedback,
195
+ 'score': feedback.score,
196
+ 'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
197
+ 'conversation_id': feedback.conversation_id,
198
+ 'timestamp': int(feedback.timestamp),
199
+ 'message_count': feedback.message_count,
200
+ 'has_retrievals': feedback.has_retrievals,
201
+ 'retrieval_count': feedback.retrieval_count,
202
+ 'user_query': feedback.user_query,
203
+ 'bot_response': feedback.bot_response,
204
+ 'created_at': feedback.created_at,
205
+ 'retrieved_data': retrieved_data_for_db
206
+ }
207
+
208
+ # Execute insert
209
+ logger.info("🚀 SQL EXECUTION: Executing INSERT query...")
210
+ cursor.execute(sql, params)
211
+
212
+ logger.info("✅ SQL EXECUTION: Query executed successfully")
213
+ logger.info(f" - Rows affected: 1")
214
+ logger.info(f" - Status: SUCCESS")
215
+
216
+ cursor.close()
217
+ logger.info("✅ SNOWFLAKE INSERT: Feedback inserted successfully")
218
+ logger.info(f"📝 Inserted feedback: {feedback.feedback_id}")
219
+ logger.info("=" * 80)
220
+ return True
221
+
222
+ except Exception as e:
223
+ # Check if it's a Snowflake error
224
+ if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
225
+ logger.error(f"❌ SQL EXECUTION ERROR: {e}")
226
+ logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
227
+ logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
228
+ else:
229
+ logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
230
+ logger.error(f" - Error: {e}")
231
+ logger.error("=" * 80)
232
+ return False
233
+
234
+ def __enter__(self):
235
+ """Context manager entry"""
236
+ self.connect()
237
+ return self
238
+
239
+ def __exit__(self, exc_type, exc_val, exc_tb):
240
+ """Context manager exit"""
241
+ self.disconnect()
242
+
243
+
244
+ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
245
+ """Create Snowflake connector from environment variables"""
246
+ user = os.getenv("SNOWFLAKE_USER")
247
+ password = os.getenv("SNOWFLAKE_PASSWORD")
248
+ account = os.getenv("SNOWFLAKE_ACCOUNT")
249
+ warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
250
+ database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
251
+ schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
252
+
253
+ if not all([user, password, account, warehouse]):
254
+ print("⚠️ Snowflake credentials not found in environment variables")
255
+ print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
256
+ return None
257
+
258
+ return SnowflakeFeedbackConnector(
259
+ user=user,
260
+ password=password,
261
+ account=account,
262
+ warehouse=warehouse,
263
+ database=database,
264
+ schema=schema
265
+ )
266
+
267
+
268
+ def save_to_snowflake(feedback: UserFeedback) -> bool:
269
+ """Helper function to save feedback to Snowflake"""
270
+ logger.info("=" * 80)
271
+ logger.info("🔵 SNOWFLAKE SAVE: Starting save process")
272
+ logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
273
+
274
+ connector = get_snowflake_connector_from_env()
275
+
276
+ if not connector:
277
+ logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
278
+ logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
279
+ logger.info("=" * 80)
280
+ return False
281
+
282
+ try:
283
+ logger.info("📡 SNOWFLAKE SAVE: Establishing connection...")
284
+ connector.connect()
285
+ logger.info("✅ SNOWFLAKE SAVE: Connection established")
286
+
287
+ logger.info("📥 SNOWFLAKE SAVE: Attempting to insert feedback...")
288
+ success = connector.insert_feedback(feedback)
289
+
290
+ logger.info("🔌 SNOWFLAKE SAVE: Disconnecting...")
291
+ connector.disconnect()
292
+
293
+ if success:
294
+ logger.info("✅ SNOWFLAKE SAVE: Successfully saved feedback")
295
+ else:
296
+ logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
297
+
298
+ logger.info("=" * 80)
299
+ return success
300
+ except Exception as e:
301
+ logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
302
+ logger.error(f" - Error: {e}")
303
+ logger.info("=" * 80)
304
+ return False
305
+
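The connector is configured entirely through environment variables; a minimal end-to-end sketch (the credential values are placeholders, and the feedback dict is the minimal set of required keys):

import os
from src.reporting.feedback_schema import create_feedback_from_dict
from src.reporting.snowflake_connector import save_to_snowflake

os.environ.setdefault("SNOWFLAKE_USER", "my_user")          # placeholder
os.environ.setdefault("SNOWFLAKE_PASSWORD", "my_password")  # placeholder
os.environ.setdefault("SNOWFLAKE_ACCOUNT", "my_account")    # placeholder
os.environ.setdefault("SNOWFLAKE_WAREHOUSE", "my_wh")       # placeholder

feedback = create_feedback_from_dict({
    "score": 5, "is_feedback_about_last_retrieval": False,
    "conversation_id": "conv-123", "timestamp": 1700000000,
    "message_count": 2, "has_retrievals": False, "retrieval_count": 0,
})
ok = save_to_snowflake(feedback)  # returns False (and logs why) instead of raising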
src/retrieval/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """Document retrieval and filtering utilities."""
2
+
3
+ from .filter import create_filter, FilterBuilder
4
+ from .context import ContextRetriever, get_context
5
+ from .hybrid import HybridRetriever, get_available_search_modes, get_search_mode_description
6
+
7
+ __all__ = [
8
+ "create_filter",
9
+ "FilterBuilder",
10
+ "ContextRetriever",
11
+ "get_context",
12
+ "HybridRetriever",
13
+ "get_available_search_modes",
14
+ "get_search_mode_description"
15
+ ]
src/retrieval/colbert_cache.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ ColBERT embeddings cache for test set documents.
3
+ Provides O(1) lookup for ColBERT embeddings during late interaction.
4
+ """
5
+
6
+ import json
7
+ import numpy as np
8
+ from pathlib import Path
9
+ from typing import Dict, Optional, Any
10
+
11
+
12
+ class ColBERTCache:
13
+ """Cache for ColBERT embeddings of test set documents."""
14
+
15
+ def __init__(self, cache_file: str = "test_set_colbert_cache.json"):
16
+ self.cache_file = Path("outputs/caches") / cache_file
17
+ self.embeddings_cache: Dict[str, np.ndarray] = {}
18
+ self._load_cache()
19
+
20
+ def _load_cache(self):
21
+ """Load embeddings from cache file."""
22
+ if not self.cache_file.exists():
23
+ print(f"⚠️ ColBERT cache not found: {self.cache_file}")
24
+ print("💡 Run 'python precalculate_test_set_colbert.py' to create cache")
25
+ return
26
+
27
+ print(f"📂 Loading ColBERT cache from {self.cache_file}...")
28
+
29
+ try:
30
+ with open(self.cache_file, 'r') as f:
31
+ cache_data = json.load(f)
32
+
33
+ # Reconstruct embeddings from compressed format
34
+ for doc_id, data in cache_data.items():
35
+ embedding_min = data['min']
36
+ embedding_max = data['max']
37
+ quantized_embedding = np.array(data['embedding'], dtype=np.uint8)
38
+
39
+ # Reconstruct original embedding
40
+ reconstructed = (quantized_embedding.astype(np.float32) / 255.0) * (embedding_max - embedding_min) + embedding_min
41
+ self.embeddings_cache[doc_id] = reconstructed.reshape(data['shape'])
42
+
43
+ print(f"✅ Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")
44
+
45
+ except Exception as e:
46
+ print(f"❌ Error loading ColBERT cache: {e}")
47
+ self.embeddings_cache = {}
48
+
49
+ def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
50
+ """Get ColBERT embedding for a document (O(1) lookup)."""
51
+ return self.embeddings_cache.get(document_text)
52
+
53
+ def has_embedding(self, document_text: str) -> bool:
54
+ """Check if embedding exists for document."""
55
+ return document_text in self.embeddings_cache
56
+
57
+ def get_cache_stats(self) -> Dict[str, Any]:
58
+ """Get cache statistics."""
59
+ return {
60
+ 'total_embeddings': len(self.embeddings_cache),
61
+ 'cache_file': str(self.cache_file),
62
+ 'cache_exists': self.cache_file.exists()
63
+ }
64
+
65
+
66
+ # Global cache instance
67
+ _colbert_cache = None
68
+
69
+ def get_colbert_cache() -> ColBERTCache:
70
+ """Get global ColBERT cache instance."""
71
+ global _colbert_cache
72
+ if _colbert_cache is None:
73
+ _colbert_cache = ColBERTCache()
74
+ return _colbert_cache
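The cache reader above reverses a min-max uint8 quantization. A small round-trip sketch of that scheme; the writer side (precalculate_test_set_colbert.py) is assumed to quantize symmetrically, which is not shown in this commit:

import numpy as np

emb = np.random.randn(32, 128).astype(np.float32)  # illustrative token-level embedding
lo, hi = float(emb.min()), float(emb.max())

# Quantize: map [lo, hi] linearly onto [0, 255] (assumed writer-side step)
q = np.round((emb - lo) / (hi - lo) * 255.0).astype(np.uint8)

# Dequantize: exactly the reconstruction _load_cache performs
recon = (q.astype(np.float32) / 255.0) * (hi - lo) + lo
print(float(np.abs(recon - emb).max()))  # error bounded by roughly (hi - lo) / 255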
src/retrieval/context.py ADDED
@@ -0,0 +1,881 @@
1
+ """Context retrieval with reranking capabilities."""
2
+
3
+ import os
4
+ from typing import List, Optional, Tuple, Dict, Any
5
+ from langchain.schema import Document
6
+ from langchain_community.vectorstores import Qdrant
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from sentence_transformers import CrossEncoder
9
+ import numpy as np
10
+ import torch
11
+ from qdrant_client.http import models as rest
12
+ import traceback
13
+
14
+ from .filter import create_filter
15
+
16
+ class ContextRetriever:
17
+ """
18
+ Context retriever for hybrid search with optional filtering and reranking.
19
+ """
20
+
21
+ def __init__(self, vectorstore: Qdrant, config: dict = None):
22
+ """
23
+ Initialize the context retriever.
24
+
25
+ Args:
26
+ vectorstore: Qdrant vector store instance
27
+ config: Configuration dictionary
28
+ """
29
+ self.vectorstore = vectorstore
30
+ self.config = config or {}
31
+ self.reranker = None
32
+
33
+ # BM25 attributes
34
+ self.bm25_vectorizer = None
35
+ self.bm25_matrix = None
36
+ self.bm25_documents = None
37
+
38
+ # Initialize reranker if available
39
+ # Try to get reranker model from different config paths
40
+ self.reranker_model_name = (
41
+ self.config.get('retrieval', {}).get('reranker_model') or
42
+ self.config.get('ranker', {}).get('model') or
43
+ self.config.get('reranker_model') or
44
+ 'BAAI/bge-reranker-v2-m3'
45
+ )
46
+ self.reranker_type = self._detect_reranker_type(self.reranker_model_name)
47
+
48
+ try:
49
+ if self.reranker_type == 'colbert':
50
+ from colbert.infra import Run, ColBERTConfig
51
+ from colbert.modeling.checkpoint import Checkpoint
52
+ # ColBERT uses late interaction - different implementation needed
53
+ print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})")
54
+ print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)")
55
+
56
+ # Create ColBERT config for CPU mode
57
+ colbert_config = ColBERTConfig(
58
+ doc_maxlen=300,
59
+ query_maxlen=32,
60
+ nbits=2,
61
+ kmeans_niters=4,
62
+ root="./colbert_data"
63
+ )
64
+
65
+ # Load checkpoint (e.g. "colbert-ir/colbertv2.0")
66
+ self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
67
+ self.colbert_model = self.colbert_checkpoint.model
68
+ self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
69
+ self.reranker = self._colbert_rerank # attach wrapper function
70
+ print(f"✅ COLBERT: Model and tokenizer loaded successfully")
71
+
72
+ else:
73
+ # Standard CrossEncoder for BGE and other models
74
+ from sentence_transformers import CrossEncoder
75
+ self.reranker = CrossEncoder(self.reranker_model_name)
76
+ print(f"✅ RERANKER: Initialized {self.reranker_model_name}")
77
+ print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)")
78
+ except Exception as e:
79
+ print(f"⚠️ Reranker initialization failed: {e}")
80
+ self.reranker = None
81
+
82
+ def _detect_reranker_type(self, model_name: str) -> str:
83
+ """
84
+ Detect the type of reranker based on model name.
85
+
86
+ Args:
87
+ model_name: Name of the reranker model
88
+
89
+ Returns:
90
+ 'colbert' for ColBERT models, 'crossencoder' for others
91
+ """
92
+ model_name_lower = model_name.lower()
93
+
94
+ # ColBERT model patterns
95
+ colbert_patterns = [
96
+ 'colbert',
97
+ 'colbert-ir',
98
+ 'colbertv2',
99
+ 'colbert-v2'
100
+ ]
101
+
102
+ for pattern in colbert_patterns:
103
+ if pattern in model_name_lower:
104
+ return 'colbert'
105
+
106
+ # Default to cross-encoder for BGE and other models
107
+ return 'crossencoder'
108
+
109
+ def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]:
110
+ """
111
+ Perform similarity search and fetch ColBERT embeddings for documents.
112
+
113
+ Args:
114
+ query: Search query
115
+ k: Number of documents to retrieve
116
+ **kwargs: Additional search parameters (filter, etc.)
117
+
118
+ Returns:
119
+ List of (Document, score) tuples with ColBERT embeddings in metadata
120
+ """
121
+ try:
122
+ print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings")
123
+
124
+ # Use the vectorstore's similarity_search_with_score method instead of direct client
125
+ # This ensures proper filter handling
126
+ if 'filter' in kwargs and kwargs['filter']:
127
+ # Use the vectorstore method with filter
128
+ result = self.vectorstore.similarity_search_with_score(
129
+ query,
130
+ k=k,
131
+ filter=kwargs['filter']
132
+ )
133
+ else:
134
+ # Use the vectorstore method without filter
135
+ result = self.vectorstore.similarity_search_with_score(query, k=k)
136
+
137
+ # Convert to the format we need
138
+ if isinstance(result, tuple) and len(result) == 2:
139
+ documents, scores = result
140
+ elif isinstance(result, list):
141
+ documents = []
142
+ scores = []
143
+ for item in result:
144
+ if isinstance(item, tuple) and len(item) == 2:
145
+ doc, score = item
146
+ documents.append(doc)
147
+ scores.append(score)
148
+ else:
149
+ documents.append(item)
150
+ scores.append(0.0)
151
+ else:
152
+ documents = []
153
+ scores = []
154
+
155
+ # Now we need to fetch the ColBERT embeddings for these documents
156
+ # We'll use the Qdrant client directly for this part since we need specific payload fields
157
+ from qdrant_client.http import models as rest
158
+
159
+ collection_name = self.vectorstore.collection_name
160
+
161
+ # Get document IDs from the retrieved documents
162
+ doc_ids = []
163
+ for doc in documents:
164
+ # Extract ID from document metadata or use page_content hash as fallback
165
+ doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
166
+ if not doc_id:
167
+ # Use a hash of the content as ID
168
+ import hashlib
169
+ doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
170
+ doc_ids.append(doc_id)
171
+
172
+ # Fetch documents with ColBERT embeddings from Qdrant
173
+ search_result = self.vectorstore.client.retrieve(
174
+ collection_name=collection_name,
175
+ ids=doc_ids,
176
+ with_payload=True,
177
+ with_vectors=False
178
+ )
179
+
180
+ # Convert results to Document objects with ColBERT embeddings
181
+ enhanced_documents = []
182
+ enhanced_scores = []
183
+
184
+ # Create a mapping from doc_id to original score
185
+ doc_id_to_score = {}
186
+ for i, doc in enumerate(documents):
187
+ doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
188
+ if not doc_id:
189
+ import hashlib
190
+ doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
191
+ doc_id_to_score[doc_id] = scores[i]
192
+
193
+ for point in search_result:
194
+ # Extract payload
195
+ payload = point.payload
196
+
197
+ # Get the original score for this document
198
+ doc_id = str(point.id)
199
+ original_score = doc_id_to_score.get(doc_id, 0.0)
200
+
201
+ # Create Document object with ColBERT embeddings
202
+ doc = Document(
203
+ page_content=payload.get('page_content', ''),
204
+ metadata={
205
+ **payload.get('metadata', {}),
206
+ 'colbert_embedding': payload.get('colbert_embedding'),
207
+ 'colbert_model': payload.get('colbert_model'),
208
+ 'colbert_calculated_at': payload.get('colbert_calculated_at')
209
+ }
210
+ )
211
+
212
+ enhanced_documents.append(doc)
213
+ enhanced_scores.append(original_score)
214
+
215
+ print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings")
216
+
217
+ return list(zip(enhanced_documents, enhanced_scores))
218
+
219
+ except Exception as e:
220
+ print(f"❌ COLBERT RETRIEVAL ERROR: {e}")
221
+ print(f"❌ Falling back to regular similarity search")
222
+
223
+ # Fallback to regular search - handle filter parameter correctly
224
+ if 'filter' in kwargs and kwargs['filter']:
225
+ return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter'])
226
+ else:
227
+ return self.vectorstore.similarity_search_with_score(query, k=k)
228
+
229
+ def retrieve_context(
230
+ self,
231
+ query: str,
232
+ k: int = 5,
233
+ reports: Optional[List[str]] = None,
234
+ sources: Optional[List[str]] = None,
235
+ subtype: Optional[List[str]] = None,
236
+ year: Optional[List[str]] = None,
237
+ district: Optional[List[str]] = None,
238
+ filenames: Optional[List[str]] = None,
239
+ use_reranking: bool = False,
240
+ qdrant_filter: Optional[rest.Filter] = None
241
+ ) -> List[Document]:
242
+ """
243
+ Retrieve context documents using hybrid search with optional filtering and reranking.
244
+
245
+ Args:
246
+ query: User query
247
+ k: Number of documents to retrieve
248
+ reports: List of report names to filter by
249
+ sources: List of sources to filter by
250
+ subtype: List of subtypes to filter by
251
+ year: List of years to filter by
+ district: List of districts to filter by
+ filenames: List of filenames to filter by
252
+ use_reranking: Whether to apply reranking
253
+ qdrant_filter: Pre-built Qdrant filter to use
254
+
255
+ Returns:
256
+ List of retrieved documents
257
+ """
258
+ try:
259
+ # Determine how many documents to retrieve
260
+ retrieve_k = k  # could be raised (e.g. k * 3) to over-retrieve before reranking
261
+
262
+ # Build search kwargs
263
+ search_kwargs = {}
264
+
265
+ # Use qdrant_filter if provided (this takes precedence)
266
+ if qdrant_filter:
267
+ search_kwargs = {"filter": qdrant_filter}
268
+ print(f"✅ FILTERS APPLIED: Using inferred Qdrant filter")
269
+ else:
270
+ # Build filter from individual parameters
271
+ filter_obj = create_filter(
272
+ reports=reports,
273
+ sources=sources,
274
+ subtype=subtype,
275
+ year=year,
276
+ district=district,
277
+ filenames=filenames
278
+ )
279
+
280
+ if filter_obj:
281
+ search_kwargs = {"filter": filter_obj}
282
+ print(f"✅ FILTERS APPLIED: Using built filter")
283
+ else:
284
+ search_kwargs = {}
285
+ print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
286
+
287
+ # Perform vector search
288
+ try:
289
+ # Check if we need ColBERT embeddings for reranking
290
+ if use_reranking and self.reranker_type == 'colbert':
291
+ result = self._similarity_search_with_colbert_embeddings(
292
+ query,
293
+ k=retrieve_k,
294
+ **search_kwargs
295
+ )
296
+ else:
297
+ result = self.vectorstore.similarity_search_with_score(
298
+ query,
299
+ k=retrieve_k,
300
+ **search_kwargs
301
+ )
302
+
303
+ # Handle different return formats
304
+ if isinstance(result, tuple) and len(result) == 2:
305
+ documents, scores = result
306
+ elif isinstance(result, list) and len(result) > 0:
307
+ # Handle case where result is a list of (Document, score) tuples
308
+ documents = []
309
+ scores = []
310
+ for item in result:
311
+ if isinstance(item, tuple) and len(item) == 2:
312
+ doc, score = item
313
+ documents.append(doc)
314
+ scores.append(score)
315
+ else:
316
+ # Handle case where item is just a Document
317
+ documents.append(item)
318
+ scores.append(0.0) # Default score
319
+ else:
320
+ documents = []
321
+ scores = []
322
+
323
+ print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")
324
+
325
+ # If we got fewer documents than requested, try without filters
326
+ if len(documents) < retrieve_k and search_kwargs.get('filter'):
327
+ print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
328
+ try:
329
+ result_no_filter = self.vectorstore.similarity_search_with_score(
330
+ query,
331
+ k=retrieve_k
332
+ )
333
+
334
+ if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
335
+ documents_no_filter, scores_no_filter = result_no_filter
336
+ elif isinstance(result_no_filter, list):
337
+ documents_no_filter = []
338
+ scores_no_filter = []
339
+ for item in result_no_filter:
340
+ if isinstance(item, tuple) and len(item) == 2:
341
+ doc, score = item
342
+ documents_no_filter.append(doc)
343
+ scores_no_filter.append(score)
344
+ else:
345
+ documents_no_filter.append(item)
346
+ scores_no_filter.append(0.0)
347
+ else:
348
+ documents_no_filter = []
349
+ scores_no_filter = []
350
+
351
+ if len(documents_no_filter) > len(documents):
352
+ print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
353
+ documents = documents_no_filter
354
+ scores = scores_no_filter
355
+ except Exception as e:
356
+ print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")
357
+
358
+ except Exception as e:
359
+ print(f"❌ RETRIEVAL ERROR: {str(e)}")
360
+ return []
361
+
362
+ # Apply reranking if enabled
363
+ reranking_applied = False
364
+ if use_reranking and len(documents) > 1:
365
+ print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
366
+ try:
367
+ original_docs = documents.copy()
368
+ original_scores = scores.copy()
369
+
370
+ # Apply reranking
371
+ # print(f"🔍 ORIGINAL DOCS: {documents[0]}")
372
+ reranked_docs = self._apply_reranking(query, documents, scores)
373
+ # print(f"🔍 RERANKED DOCS: {reranked_docs[0]}")
374
+ reranking_applied = len(reranked_docs) > 0
375
+
376
+ if reranking_applied:
377
+ print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
378
+ documents = reranked_docs
379
+ # Update scores to reflect reranking
380
+ # scores = [0.0] * len(documents) # Reranked scores are not directly comparable
381
+ else:
+ print(f"⚠️ RERANKING FAILED: Using original order")
+ documents = original_docs
+ scores = original_scores
386
+
387
+ except Exception as e:
388
+ print(f"❌ RERANKING ERROR: {str(e)}")
389
+ print(f"⚠️ RERANKING FAILED: Using original order")
390
+ reranking_applied = False
391
+ elif use_reranking and len(documents) <= 1:
+ print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
+ else:
+ print(f"ℹ️ RERANKING: Skipped (disabled)")
402
+
403
+ # Limit to requested number of documents
404
+ documents = documents[:k]
405
+ scores = scores[:k] if scores else [0.0] * len(documents)
406
+
407
+ # Add metadata to documents
408
+ for i, (doc, score) in enumerate(zip(documents, scores)):
409
+ if hasattr(doc, 'metadata'):
410
+ doc.metadata.update({
411
+ 'reranking_applied': reranking_applied,
412
+ 'reranker_model': self.reranker_model_name if reranking_applied else None,
413
+ 'original_rank': i + 1,
414
+ 'final_rank': i + 1,
415
+ 'original_score': float(score) if score is not None else 0.0
416
+ })
417
+
418
+ return documents
419
+
420
+ except Exception as e:
421
+ print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
422
+ return []
423
+
424
+ def _apply_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
425
+ """
426
+ Apply reranking to documents using the appropriate reranker.
427
+
428
+ Args:
429
+ query: User query
430
+ documents: List of documents to rerank
431
+ scores: Original scores
432
+
433
+ Returns:
434
+ Reranked list of documents
435
+ """
436
+ if not self.reranker or len(documents) == 0:
437
+ return documents
438
+
439
+ try:
440
+ print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
441
+ print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")
442
+
443
+ if self.reranker_type == 'colbert':
444
+ return self._apply_colbert_reranking(query, documents, scores)
445
+ else:
446
+ return self._apply_crossencoder_reranking(query, documents, scores)
447
+
448
+ except Exception as e:
449
+ print(f"❌ RERANKING ERROR: {str(e)}")
450
+ return documents
451
+
452
+ def _apply_crossencoder_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
453
+ """
454
+ Apply reranking using CrossEncoder (BGE and other models).
455
+
456
+ Args:
457
+ query: User query
458
+ documents: List of documents to rerank
459
+ scores: Original scores
460
+
461
+ Returns:
462
+ Reranked list of documents
463
+ """
464
+ # Prepare pairs for reranking
465
+ pairs = []
466
+ for doc in documents:
467
+ pairs.append([query, doc.page_content])
468
+
469
+ print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")
470
+
471
+ # Get reranking scores using the correct CrossEncoder API
472
+ rerank_scores = self.reranker.predict(pairs)
473
+
474
+ # Handle single score case
475
+ if not isinstance(rerank_scores, (list, np.ndarray)):
476
+ rerank_scores = [rerank_scores]
477
+
478
+ # Ensure we have the right number of scores
479
+ if len(rerank_scores) != len(documents):
480
+ print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
481
+ return documents
482
+
483
+ print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
484
+ print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...") # Show first 5 scores
485
+
486
+ # Combine documents with their rerank scores
487
+ doc_scores = list(zip(documents, rerank_scores))
488
+
489
+ # Sort by rerank score (descending)
490
+ doc_scores.sort(key=lambda x: x[1], reverse=True)
491
+
492
+ # Extract reranked documents and store scores in metadata
493
+ reranked_docs = []
494
+ for i, (doc, rerank_score) in enumerate(doc_scores):
495
+ # Find original index for original score
496
+ original_idx = documents.index(doc)
497
+ original_score = scores[original_idx] if original_idx < len(scores) else 0.0
498
+
499
+ # Create new document with reranking metadata
500
+ new_doc = Document(
501
+ page_content=doc.page_content,
502
+ metadata={
503
+ **doc.metadata,
504
+ 'reranking_applied': True,
505
+ 'reranker_model': self.reranker_model_name,
506
+ 'reranker_type': self.reranker_type,
507
+ 'original_rank': original_idx + 1,
508
+ 'final_rank': i + 1,
509
+ 'original_score': float(original_score),
510
+ 'reranked_score': float(rerank_score)
511
+ }
512
+ )
513
+ reranked_docs.append(new_doc)
514
+
515
+ print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")
516
+
517
+ return reranked_docs
518
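For context, a minimal sketch of the sentence-transformers CrossEncoder contract relied on above. The model name matches the default reranker used elsewhere in this diff; weights are downloaded on first use:

from sentence_transformers import CrossEncoder

reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")
pairs = [
    ["what was audited in 2022", "The 2022 audit covered district budgets."],
    ["what was audited in 2022", "Procurement guidelines for staff travel."],
]
scores = reranker.predict(pairs)  # one relevance score per (query, passage) pair
ranked = sorted(zip(pairs, scores), key=lambda p: p[1], reverse=True)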
+
519
+ def _apply_colbert_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
520
+ """
521
+ Apply reranking using ColBERT late interaction.
522
+
523
+ Args:
524
+ query: User query
525
+ documents: List of documents to rerank
526
+ scores: Original scores
527
+
528
+ Returns:
529
+ Reranked list of documents
530
+ """
531
+ # Use the actual ColBERT reranking implementation
532
+ return self._colbert_rerank(query, documents, scores)
533
+
534
+ def _colbert_rerank(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
535
+ """
536
+ ColBERT reranking using late interaction with pre-calculated embeddings support.
537
+
538
+ Args:
539
+ query: User query
540
+ documents: List of documents to rerank
541
+ scores: Original scores
542
+
543
+ Returns:
544
+ Reranked list of documents
545
+ """
546
+ try:
547
+ print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")
548
+
549
+ # Check if documents have pre-calculated ColBERT embeddings
550
+ pre_calculated_embeddings = []
551
+ documents_without_embeddings = []
552
+ documents_without_indices = []
553
+
554
+ for i, doc in enumerate(documents):
555
+ if (hasattr(doc, 'metadata') and
556
+ 'colbert_embedding' in doc.metadata and
557
+ doc.metadata['colbert_embedding'] is not None):
558
+ # Use pre-calculated embedding
559
+ colbert_embedding = doc.metadata['colbert_embedding']
560
+ if isinstance(colbert_embedding, list):
561
+ colbert_embedding = torch.tensor(colbert_embedding)
562
+ pre_calculated_embeddings.append(colbert_embedding)
563
+ else:
564
+ # Need to calculate embedding
565
+ documents_without_embeddings.append(doc)
566
+ documents_without_indices.append(i)
567
+
568
+ # Calculate query embedding
569
+ query_embeddings = self.colbert_checkpoint.queryFromText([query])
570
+
571
+ # Calculate embeddings for documents without pre-calculated ones
572
+ if documents_without_embeddings:
573
+ print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
574
+ doc_texts = [doc.page_content for doc in documents_without_embeddings]
575
+ doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)
576
+
577
+ # Insert calculated embeddings into the right positions
578
+ for i, embedding in enumerate(doc_embeddings):
579
+ idx = documents_without_indices[i]
580
+ pre_calculated_embeddings.insert(idx, embedding)
581
+ else:
582
+ print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")
583
+
584
+ # Calculate late interaction scores
585
+ # ColBERT uses MaxSim: for each query token, find max similarity with document tokens
586
+ colbert_scores = []
587
+ for i, doc_embedding in enumerate(pre_calculated_embeddings):
588
+ # Calculate similarity matrix between query and document i
589
+ sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))
590
+
591
+ # MaxSim: for each query token, take max similarity with document
592
+ max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]
593
+
594
+ # Sum over query tokens to get final score
595
+ final_score = torch.sum(max_sim_per_query_token).item()
596
+ colbert_scores.append(final_score)
597
+
598
+ # Sort documents by ColBERT scores
599
+ doc_scores = list(zip(documents, colbert_scores))
600
+ doc_scores.sort(key=lambda x: x[1], reverse=True)
601
+
602
+ # Create reranked documents with metadata
603
+ reranked_docs = []
604
+ for i, (doc, colbert_score) in enumerate(doc_scores):
605
+ original_idx = documents.index(doc)
606
+ original_score = scores[original_idx] if original_idx < len(scores) else 0.0
607
+
608
+ new_doc = Document(
609
+ page_content=doc.page_content,
610
+ metadata={
611
+ **doc.metadata,
612
+ 'reranking_applied': True,
613
+ 'reranker_model': self.reranker_model_name,
614
+ 'reranker_type': self.reranker_type,
615
+ 'original_rank': original_idx + 1,
616
+ 'final_rank': i + 1,
617
+ 'original_score': float(original_score),
618
+ 'reranked_score': float(colbert_score),
619
+ 'colbert_score': float(colbert_score),
620
+ 'colbert_embedding_pre_calculated': 'colbert_embedding' in doc.metadata
621
+ }
622
+ )
623
+ reranked_docs.append(new_doc)
624
+
625
+ print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
626
+ print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")
627
+
628
+ return reranked_docs
629
+
630
+ except Exception as e:
631
+ print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
632
+ print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
633
+ # Fallback to original order - return documents as-is
634
+ return documents
635
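For intuition, a toy illustration of the MaxSim late-interaction scoring implemented above; random tensors stand in for real ColBERT token embeddings:

import torch

query_emb = torch.randn(3, 8)            # [query_tokens, dim]
doc_emb = torch.randn(5, 8)              # [doc_tokens, dim]
sim_matrix = query_emb @ doc_emb.T       # [3, 5] token-to-token similarities
max_per_query_token = sim_matrix.max(dim=-1).values  # best doc token per query token
score = max_per_query_token.sum().item()             # higher = more relevant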
+
636
+ def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5, reports: List[str] = None,
637
+ sources: List[str] = None, subtype: List[str] = None,
638
+ year: List[str] = None, use_reranking: bool = False,
639
+ qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
640
+ """
641
+ Retrieve context documents with scores via vector search (per the configured search strategy), with optional reranking.
642
+
643
+ Args:
644
+ query: User query
645
+ vectorstore: Optional vectorstore instance (for compatibility)
646
+ k: Number of documents to retrieve
647
+ reports: List of report names to filter by
648
+ sources: List of sources to filter by
649
+ subtype: List of document subtypes to filter by
650
+ year: List of years to filter by
651
+ use_reranking: Whether to apply reranking
652
+ qdrant_filter: Pre-built Qdrant filter
653
+
654
+ Returns:
655
+ Tuple of (documents, scores)
656
+ """
657
+ try:
658
+ # Use the provided vectorstore if available, otherwise use the instance one
659
+ if vectorstore:
660
+ self.vectorstore = vectorstore
661
+
662
+ # Determine search strategy
663
+ search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')
664
+
665
+ if search_strategy == 'vector_only':
666
+ # Vector search only
667
+ print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")
668
+
669
+ if qdrant_filter:
670
+ print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
671
+ # Pass the pre-built Qdrant filter straight through to the vector store
672
+ results = self.vectorstore.similarity_search_with_score(
673
+ query,
674
+ k=k,
675
+ filter=qdrant_filter
676
+ )
677
+ else:
678
+ # Build filter from individual parameters
679
+ filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
680
+ if filter_conditions:
681
+ print(f"✅ FILTER APPLIED: {filter_conditions}")
682
+ results = self.vectorstore.similarity_search_with_score(
683
+ query,
684
+ k=k,
685
+ filter=filter_conditions
686
+ )
687
+ else:
688
+ print(f"ℹ️ NO FILTERS APPLIED: All documents will be searched")
689
+ results = self.vectorstore.similarity_search_with_score(query, k=k)
690
+
691
+ print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
692
+ print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")
693
+
694
+ # Handle different result formats
695
+ if results and isinstance(results[0], tuple):
696
+ documents = [doc for doc, score in results]
697
+ scores = [score for doc, score in results]
698
+ print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
699
+ else:
700
+ documents = results
701
+ scores = [0.0] * len(documents)
702
+ print(f"🔍 SEARCH DEBUG: No scores available, using default")
703
+
704
+ print(f"🔧 CONVERTING: Converting {len(documents)} documents")
705
+
706
+ # Convert to Document objects and store original scores
707
+ final_documents = []
708
+ for i, (doc, score) in enumerate(zip(documents, scores)):
709
+ if hasattr(doc, 'page_content'):
710
+ new_doc = Document(
711
+ page_content=doc.page_content,
712
+ metadata=doc.metadata.copy()
713
+ )
714
+ # Store original score in metadata
715
+ new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
716
+ final_documents.append(new_doc)
717
+ else:
718
+ print(f"⚠️ WARNING: Document {i} has no page_content")
719
+
720
+ print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")
721
+
722
+ # Apply reranking if enabled
723
+ if use_reranking and len(final_documents) > 1:
724
+ print(f"🔄 RERANKING: Applying {self.reranker_model} to {len(final_documents)} documents...")
725
+ final_documents = self._apply_reranking(query, final_documents, scores)
726
+ print(f"✅ RERANKING APPLIED: {self.reranker_model}")
727
+ else:
728
+ print(f"ℹ️ RERANKING: Skipped (disabled or no documents)")
729
+
730
+ return final_documents, scores
731
+
732
+ else:
733
+ print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
734
+ return [], []
735
+
736
+ except Exception as e:
737
+ print(f"❌ RETRIEVAL ERROR: {e}")
738
+ print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
739
+ return [], []
740
+
741
+ def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
742
+ subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
743
+ """
744
+ Build Qdrant filter conditions from individual parameters.
745
+
746
+ Args:
747
+ reports: List of report names
748
+ sources: List of sources
749
+ subtype: Document subtype
750
+ year: List of years
751
+
752
+ Returns:
753
+ Qdrant filter or None
754
+ """
755
+ conditions = []
756
+
757
+ if reports:
758
+ conditions.append(rest.FieldCondition(
759
+ key="metadata.filename",
760
+ match=rest.MatchAny(any=reports)
761
+ ))
762
+
763
+ if sources:
764
+ conditions.append(rest.FieldCondition(
765
+ key="metadata.source",
766
+ match=rest.MatchAny(any=sources)
767
+ ))
768
+
769
+ if subtype:
770
+ conditions.append(rest.FieldCondition(
771
+ key="metadata.subtype",
772
+ match=rest.MatchAny(any=subtype)
773
+ ))
774
+
775
+ if year:
776
+ conditions.append(rest.FieldCondition(
777
+ key="metadata.year",
778
+ match=rest.MatchAny(any=year)
779
+ ))
780
+
781
+ if conditions:
782
+ return rest.Filter(must=conditions)
783
+
784
+ return None
785
+
786
+ def get_context(
787
+ query: str,
788
+ vectorstore: Qdrant,
789
+ k: int = 5,
790
+ reports: Optional[List[str]] = None,
791
+ sources: Optional[List[str]] = None,
792
+ subtype: Optional[str] = None,
793
+ year: Optional[str] = None,
794
+ use_reranking: bool = False,
795
+ qdrant_filter: Optional[rest.Filter] = None
796
+ ) -> List[Document]:
797
+ """
798
+ Convenience function to get context documents.
799
+
800
+ Args:
801
+ query: User query
802
+ vectorstore: Qdrant vector store instance
803
+ k: Number of documents to retrieve
804
+ reports: Optional list of report names to filter by
805
+ sources: Optional list of source categories to filter by
806
+ subtype: Optional subtype to filter by
807
+ year: Optional year to filter by
808
+ use_reranking: Whether to apply reranking
809
+ qdrant_filter: Optional pre-built Qdrant filter
810
+
811
+ Returns:
812
+ List of retrieved documents
813
+ """
814
+ retriever = ContextRetriever(vectorstore)
815
+ return retriever.retrieve_context(
816
+ query=query,
817
+ k=k,
818
+ reports=reports,
819
+ sources=sources,
820
+ subtype=subtype,
821
+ year=year,
822
+ use_reranking=use_reranking,
823
+ qdrant_filter=qdrant_filter
824
+ )
825
+
826
+
827
+ def format_context_for_llm(documents: List[Document]) -> str:
828
+ """
829
+ Format retrieved documents for LLM input.
830
+
831
+ Args:
832
+ documents: List of Document objects
833
+
834
+ Returns:
835
+ Formatted string for LLM
836
+ """
837
+ if not documents:
838
+ return ""
839
+
840
+ formatted_parts = []
841
+ for i, doc in enumerate(documents, 1):
842
+ content = doc.page_content.strip()
843
+ source = doc.metadata.get('filename', 'Unknown')
844
+
845
+ formatted_parts.append(f"Document {i} (Source: {source}):\n{content}")
846
+
847
+ return "\n\n".join(formatted_parts)
848
+
849
+
850
+ def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
851
+ """
852
+ Extract metadata summary from retrieved documents.
853
+
854
+ Args:
855
+ documents: List of Document objects
856
+
857
+ Returns:
858
+ Dictionary with metadata summary
859
+ """
860
+ if not documents:
861
+ return {}
862
+
863
+ sources = set()
864
+ years = set()
865
+ doc_types = set()
866
+
867
+ for doc in documents:
868
+ metadata = doc.metadata
869
+ if 'filename' in metadata:
870
+ sources.add(metadata['filename'])
871
+ if 'year' in metadata:
872
+ years.add(metadata['year'])
873
+ if 'source' in metadata:
874
+ doc_types.add(metadata['source'])
875
+
876
+ return {
877
+ "num_documents": len(documents),
878
+ "sources": list(sources),
879
+ "years": list(years),
880
+ "document_types": list(doc_types)
881
+ }
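Tying the module-level helpers together, a short hypothetical usage sketch (vectorstore is an assumed, pre-built Qdrant store):

docs = get_context(
    query="How were administrative costs managed?",
    vectorstore=vectorstore,
    k=3,
)
print(format_context_for_llm(docs))   # "Document 1 (Source: ...):\n..." blocks
print(get_context_metadata(docs))     # e.g. {'num_documents': 3, 'sources': [...], ...}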
src/retrieval/filter.py ADDED
@@ -0,0 +1,975 @@
1
+ """Document filtering utilities for Qdrant vector store."""
2
+
3
+ from typing import List, Optional, Union, Dict, Tuple, Any
4
+ from qdrant_client.http import models as rest
5
+ import time
6
+
7
+
8
+ class FilterBuilder:
9
+ """Builder class for creating Qdrant filters."""
10
+
11
+ def __init__(self):
12
+ self.conditions = []
13
+
14
+ def add_source_filter(self, source: Union[str, List[str]]) -> 'FilterBuilder':
15
+ """Add source filter condition."""
16
+ if source:
17
+ if isinstance(source, list):
18
+ condition = rest.FieldCondition(
19
+ key="metadata.source",
20
+ match=rest.MatchAny(any=source)
21
+ )
22
+ print(f"🔧 FilterBuilder: Added source filter for {source}")
23
+ else:
24
+ condition = rest.FieldCondition(
25
+ key="metadata.source",
26
+ match=rest.MatchValue(value=source)
27
+ )
28
+ print(f"🔧 FilterBuilder: Added source filter for '{source}'")
29
+ self.conditions.append(condition)
30
+ return self
31
+
32
+ def add_filename_filter(self, filenames: List[str]) -> 'FilterBuilder':
33
+ """Add filename filter condition."""
34
+ if filenames:
35
+ condition = rest.FieldCondition(
36
+ key="metadata.filename",
37
+ match=rest.MatchAny(any=filenames)
38
+ )
39
+ self.conditions.append(condition)
40
+ print(f"🔧 FilterBuilder: Added filename filter for {filenames}")
41
+ return self
42
+
43
+ def add_year_filter(self, years: List[str]) -> 'FilterBuilder':
44
+ """Add year filter condition."""
45
+ if years:
46
+ condition = rest.FieldCondition(
47
+ key="metadata.year",
48
+ match=rest.MatchAny(any=years)
49
+ )
50
+ self.conditions.append(condition)
51
+ print(f"🔧 FilterBuilder: Added year filter for {years}")
52
+ return self
53
+
54
+ def add_district_filter(self, districts: List[str]) -> 'FilterBuilder':
55
+ """Add district filter condition."""
56
+ if districts:
57
+ condition = rest.FieldCondition(
58
+ key="metadata.district",
59
+ match=rest.MatchAny(any=districts)
60
+ )
61
+ self.conditions.append(condition)
62
+ print(f"🔧 FilterBuilder: Added district filter for {districts}")
63
+ return self
64
+
65
+ def add_custom_filter(self, key: str, value: Union[str, List[str]]) -> 'FilterBuilder':
66
+ """Add custom filter condition."""
67
+ if isinstance(value, list):
68
+ condition = rest.FieldCondition(
69
+ key=key,
70
+ match=rest.MatchAny(any=value)
71
+ )
72
+ else:
73
+ condition = rest.FieldCondition(
74
+ key=key,
75
+ match=rest.MatchValue(value=value)
76
+ )
77
+ self.conditions.append(condition)
78
+ return self
79
+
80
+ def build(self) -> rest.Filter:
81
+ """Build the final filter."""
82
+ if not self.conditions:
83
+ return None
84
+
85
+ return rest.Filter(must=self.conditions)
86
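For example, a sketch of chaining the builder to restrict retrieval to Local Government reports from 2022 or 2023 (values are placeholders):

flt = (
    FilterBuilder()
    .add_source_filter("Local Government")
    .add_year_filter(["2022", "2023"])
    .build()
)
# flt is a rest.Filter with two must-conditions; pass it to the store as
# vectorstore.similarity_search_with_score(query, k=5, filter=flt)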
+
87
+
88
+ def create_filter(
89
+ reports: List[str] = None,
90
+ sources: Union[str, List[str]] = None,
91
+ subtype: List[str] = None,
92
+ year: List[str] = None,
93
+ district: List[str] = None,
94
+ filenames: List[str] = None
95
+ ) -> rest.Filter:
96
+ """
97
+ Create a search filter for Qdrant (legacy function for compatibility).
98
+
99
+ Args:
100
+ reports: List of specific report filenames
101
+ sources: Source category
102
+ subtype: List of subtypes/filenames
103
+ year: List of years
104
+ district: List of districts
105
+ filenames: List of specific filenames (mutually exclusive with other filters)
106
+
107
+ Returns:
108
+ Qdrant Filter object
109
+
110
+ Note:
111
+ If filenames are provided, ONLY filename filtering is applied (mutually exclusive)
112
+ """
113
+ builder = FilterBuilder()
114
+
115
+ # Check if filename filtering is requested (mutually exclusive)
116
+ # Both filenames and reports serve the same purpose (backward compatibility)
117
+ # Prefer filenames, fallback to reports for legacy support
118
+ target_filenames = filenames if filenames else reports
119
+
120
+ if target_filenames and len(target_filenames) > 0:
121
+ # ONLY apply filename filter, ignore all other filters
122
+ print(f"🔍 FILTER APPLIED: Filenames = {target_filenames} (mutually exclusive mode)")
123
+ builder.add_filename_filter(target_filenames)
124
+ else:
125
+ # Otherwise, filter by source and subtype
126
+ print(f"🔍 FILTER APPLIED: Sources = {sources}, Subtype = {subtype}, Year = {year}, District = {district}")
127
+ if sources:
128
+ print(f"✅ Adding source filter: metadata.source = '{sources}'")
129
+ builder.add_source_filter(sources)
130
+ if subtype:
131
+ print(f"✅ Adding subtype filter: metadata.filename IN {subtype}")
132
+ builder.add_filename_filter(subtype)
133
+ if year:
134
+ print(f"✅ Adding year filter: metadata.year IN {year}")
135
+ builder.add_year_filter(year)
136
+
137
+ if district:
138
+ print(f"✅ Adding district filter: metadata.district IN {district}")
139
+ builder.add_district_filter(district)
140
+
141
+ filter_obj = builder.build()
142
+
143
+ if filter_obj:
144
+ print(f"�� FINAL FILTER: {len(filter_obj.must)} condition(s) applied")
145
+ for i, condition in enumerate(filter_obj.must, 1):
146
+ print(f" Condition {i}: {condition.key} = {condition.match}")
147
+ else:
148
+ print("⚠️ NO FILTERS APPLIED: All documents will be searched")
149
+
150
+ return filter_obj
151
+
152
+
153
+ def create_advanced_filter(
154
+ must_conditions: List[dict] = None,
155
+ should_conditions: List[dict] = None,
156
+ must_not_conditions: List[dict] = None
157
+ ) -> rest.Filter:
158
+ """
159
+ Create advanced filter with multiple condition types.
160
+
161
+ Args:
162
+ must_conditions: Conditions that must match
163
+ should_conditions: Conditions that should match (OR logic)
164
+ must_not_conditions: Conditions that must not match
165
+
166
+ Returns:
167
+ Qdrant Filter object
168
+ """
169
+ filter_dict = {}
170
+
171
+ if must_conditions:
172
+ filter_dict["must"] = [
173
+ _dict_to_field_condition(cond) for cond in must_conditions
174
+ ]
175
+
176
+ if should_conditions:
177
+ filter_dict["should"] = [
178
+ _dict_to_field_condition(cond) for cond in should_conditions
179
+ ]
180
+
181
+ if must_not_conditions:
182
+ filter_dict["must_not"] = [
183
+ _dict_to_field_condition(cond) for cond in must_not_conditions
184
+ ]
185
+
186
+ if not filter_dict:
187
+ return None
188
+
189
+ return rest.Filter(**filter_dict)
190
+
191
+
192
+ def _dict_to_field_condition(condition_dict: dict) -> rest.FieldCondition:
193
+ """Convert dictionary to FieldCondition."""
194
+ key = condition_dict["key"]
195
+ value = condition_dict["value"]
196
+
197
+ if isinstance(value, list):
198
+ match = rest.MatchAny(any=value)
199
+ else:
200
+ match = rest.MatchValue(value=value)
201
+
202
+ return rest.FieldCondition(key=key, match=match)
203
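A hypothetical call showing the dict shape create_advanced_filter expects; list values become MatchAny, scalars become MatchValue:

flt = create_advanced_filter(
    must_conditions=[{"key": "metadata.source", "value": "Consolidated"}],
    should_conditions=[{"key": "metadata.year", "value": ["2022", "2023"]}],
    must_not_conditions=[{"key": "metadata.subtype", "value": "guidance"}],
)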
+
204
+
205
+ def validate_filter(filter_obj: rest.Filter) -> bool:
206
+ """
207
+ Validate that a filter object is properly constructed.
208
+
209
+ Args:
210
+ filter_obj: Qdrant Filter object
211
+
212
+ Returns:
213
+ True if valid, raises ValueError if invalid
214
+ """
215
+ if filter_obj is None:
216
+ return True
217
+
218
+ if not isinstance(filter_obj, rest.Filter):
219
+ raise ValueError("Filter must be a rest.Filter object")
220
+
221
+ # Check that at least one condition type is present
222
+ has_conditions = any([
223
+ hasattr(filter_obj, 'must') and filter_obj.must,
224
+ hasattr(filter_obj, 'should') and filter_obj.should,
225
+ hasattr(filter_obj, 'must_not') and filter_obj.must_not
226
+ ])
227
+
228
+ if not has_conditions:
229
+ raise ValueError("Filter must have at least one condition")
230
+
231
+ return True
232
+
233
+
234
+ def infer_filters_from_query(
235
+ query: str,
236
+ available_metadata: dict,
237
+ llm_client=None
238
+ ) -> Tuple[rest.Filter, Union[dict, None]]:
239
+ """
240
+ Automatically infer filters from a query using LLM analysis.
241
+
242
+ Args:
243
+ query: User query to analyze
244
+ available_metadata: Available metadata values in the vectorstore
245
+ llm_client: LLM client for analysis (optional)
246
+
247
+ Returns:
248
+ Tuple of (Qdrant Filter object with inferred conditions or None, filter summary dict or None)
249
+ """
250
+ print(f"�� AUTO-INFERRING FILTERS from query: '{query[:50]}...'")
251
+
252
+ # Check if LLM client is available
253
+ if not llm_client:
254
+ print(f"❌ LLM CLIENT MISSING: Cannot use LLM analysis, falling back to rule-based")
255
+ return _infer_filters_rule_based(query, available_metadata), None
256
+
257
+ # Extract available options
258
+ available_sources = available_metadata.get('sources', [])
259
+ available_years = available_metadata.get('years', [])
260
+ available_filenames = available_metadata.get('filenames', [])
261
+
262
+ print(f"📊 Available metadata: sources={len(available_sources)}, years={len(available_years)}, filenames={len(available_filenames)}")
263
+
264
+ # Try LLM analysis first
265
+ print(f" LLM ANALYSIS: Attempting LLM-based filter inference...")
266
+ llm_result = _analyze_query_with_llm(
267
+ query=query,
268
+ available_metadata=available_metadata,
269
+ llm_client=llm_client
270
+ )
271
+
272
+ if llm_result:
273
+ print(f"✅ LLM SUCCESS: LLM successfully inferred filters")
274
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
275
+ qdrant_filter, filter_summary = _build_qdrant_filter(llm_result)
276
+ if qdrant_filter:
277
+ print(f"✅ QDRANT FILTER: Successfully built Qdrant filter")
278
+ # print(f"✅ INFERRED FILTERS: {qdrant_filter}")
279
+ return qdrant_filter, filter_summary
280
+ else:
281
+ print(f"❌ QDRANT FILTER: Failed to build Qdrant filter, trying rule-based fallback")
282
+ rule_based_result = _infer_filters_rule_based(query, available_metadata)
283
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
284
+ qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
285
+ if qdrant_filter:
286
+ print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
287
+ return qdrant_filter, filter_summary
288
+ else:
289
+ print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
290
+ return None, None
291
+ else:
292
+ print(f"⚠️ LLM FAILED: LLM could not infer filters, trying rule-based fallback")
293
+ rule_based_result = _infer_filters_rule_based(query, available_metadata)
294
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
295
+ qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
296
+ if qdrant_filter:
297
+ print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
298
+ return qdrant_filter, filter_summary
299
+ else:
300
+ print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
301
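To make the LLM contract concrete, a self-contained sketch with a stub client (the stub and its canned JSON are illustrative, not part of this diff):

class StubLLM:
    """Stand-in client returning the JSON contract the prompt asks for."""
    def invoke(self, prompt: str) -> str:
        return '{"years": ["2022"], "sources": [], "subtype": [], "confidence": 1.0, "reasoning": "explicit year"}'

flt, summary = infer_filters_from_query(
    query="What challenges arose in 2022?",
    available_metadata={"sources": ["Local Government"], "years": ["2021", "2022"], "filenames": []},
    llm_client=StubLLM(),
)
# flt -> rest.Filter(must=[FieldCondition(key="metadata.year", match=MatchValue(value="2022"))])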
+ return None, None
302
+
303
+
304
+ def _analyze_query_with_llm(
305
+ query: str,
306
+ available_metadata: Dict[str, List[str]],
307
+ llm_client=None
308
+ ) -> dict:
309
+
310
+
311
+ """
312
+ - Filenames: {available_metadata.get('filenames', [])}
313
+
314
+ 📁 FILENAME FILTERING (Use Sparingly):
315
+ - Only if specific filename explicitly mentioned
316
+ - Prefer source/subtype over filename
317
+ - Be very conservative
318
+
319
+
320
+ "filenames": ["filename1", "filename2"] or [],
321
+ - For filenames: Only use if you have high confidence and can identify specific files
322
+ """
323
+
324
+
325
+ """
326
+ Use LLM to analyze query and infer appropriate filters.
327
+
328
+ Args:
329
+ query: User query to analyze
330
+ available_metadata: Available metadata values in the vectorstore
331
+ llm_client: LLM client for analysis
332
+
333
+ Returns:
334
+ Dictionary with inferred filters or empty dict if failed
335
+ """
336
+ if not llm_client:
337
+ print("❌ LLM CLIENT MISSING: Cannot analyze query without LLM client")
338
+ return {}
339
+
340
+ try:
341
+ print(f" LLM ANALYSIS: Analyzing query with LLM...")
342
+
343
+
344
+ """
345
+ For example: "What is the expected ... in 2024" - this refference to a future statement, so retrieving documents for 2023, 2022 and 2021 can be relevant too
346
+ Another example: "What is the GDP increase now compared to 2022" - this is a relative statement, refferring to past data, so both Year 2022, and now - 2025 needs to be detected/marked
347
+ """
348
+
349
+ # Create prompt for LLM analysis
350
+ prompt = f"""
351
+ You are a filter inference system. Analyze this query and return ONLY a JSON object.
352
+
353
+ Query: "{query}"
354
+
355
+ Available metadata:
356
+ - Sources: {available_metadata.get('sources', [])}
357
+ - Years: {available_metadata.get('years', [])}
358
+
359
+ FILTER INFERENCE GUIDELINES:
360
+
361
+ YEAR FILTERING (Be VERY Conservative):
362
+ ✅ INFER YEARS ONLY IF:
363
+ - Explicit 4-digit years: "2022", "2023", "2021"
364
+ - Clear relative terms: "last year", "this year", "recent", "current year" (for the context, now is 2025)
365
+ - Temporal context: "annual report 2022", "audit for 2023"
366
+ - Give multiple years for complex queries.
367
+
368
+
369
+ ❌ DO NOT INFER YEARS FOR:
370
+ - Vague terms: "implementation", "activities", "costs", "challenges", "issues"
371
+ - General concepts: "PDM", "administrative", "budget", "staff"
372
+ - Process descriptions: "how were", "what challenges", "management of"
373
+
374
+ 🏛️ SOURCE FILTERING (Context-Based):
375
+ - "Ministry, Department and Agency" → Central government, ministries, departments, PS/ST
376
+ - "Local Government" → Districts, municipalities, local authorities, DLG
377
+ - "Consolidated" → Annual consolidated reports, OAG reports
378
+ - "Thematic" → Special studies, thematic reports
379
+
380
+ 📄 SUBTYPE FILTERING (Document Type):
381
+ - "audit" → Audit reports, reviews, examinations
382
+ - "report" → General reports, annual reports
383
+ - "guidance" → Guidelines, directives, circulars
384
+
385
+ CONFIDENCE SCORING:
386
+ - 0.9-1.0: Crystal clear indicators (explicit years, specific sources)
387
+ - 0.7-0.8: Good indicators (relative years, clear context)
388
+ - 0.5-0.6: Moderate indicators (some context clues)
389
+ - 0.0-0.4: Low confidence (vague or unclear)
390
+
391
+ EXAMPLES:
392
+ ✅ "What challenges arose in 2022?" → years: ["2022"], confidence: 1
393
+ ✅ "How were administrative costs managed in our government?" → sources: ["Local Government"], confidence: 0.75
394
+ ✅ "PDM implementation guidelines from last year" → years: ["2024"], confidence: 0.9
395
+ ❌ "What issues arose with budget execution?" → NO FILTERS, confidence: 0.2
396
+ ❌ "How were tools related to administrative costs?" → NO FILTERS, confidence: 0.1
397
+
398
+ RESPONSE FORMAT (JSON only):
399
+ {{
400
+ "years": ["2022", "2023"] or [],
401
+ "sources": ["Ministry, Department and Agency", "Local Government"] or [],
402
+ "subtype": ["audit", "report"] or [],
403
+ "confidence": 0.8,
404
+ "reasoning": "Very brief explanation of filter choices"
405
+ }}
406
+
407
+ Rules:
408
+ - Use OR logic (SHOULD) for multiple values
409
+ - Prefer sources over filenames
410
+ - Only include years if clearly mentioned
411
+ - Return null for unclear fields
412
+ - For sources/subtypes: Include at least 3 candidates unless confidence is high and you can identify exactly one source (MUST)
413
+ - For years: If you want to include, then include at least 2 candidates unless confidence is high and you can identify exactly one year (MUST)
414
+ """
415
+
416
+ print(f"🔄 LLM CALL: Sending prompt to LLM...")
417
+ try:
418
+ # Try different methods to call the LLM
419
+ if hasattr(llm_client, 'invoke'):
420
+ response = llm_client.invoke(prompt)
421
+ elif hasattr(llm_client, 'generate'):
422
+ response = llm_client.generate([{"role": "user", "content": prompt}])
423
+ elif hasattr(llm_client, 'call'):
424
+ response = llm_client.call(prompt)
425
+ elif hasattr(llm_client, 'predict'):
426
+ response = llm_client.predict(prompt)
427
+ else:
428
+ # Try to call it directly
429
+ response = llm_client(prompt)
430
+
431
+ print(f"✅ LLM CALL SUCCESS: Received response from LLM")
432
+
433
+ # Extract content from response
434
+ if hasattr(response, 'content'):
435
+ response_content = response.content
436
+ elif hasattr(response, 'text'):
437
+ response_content = response.text
438
+ elif isinstance(response, str):
439
+ response_content = response
440
+ else:
441
+ response_content = str(response)
442
+
443
+ print(f"🔄 LLM RESPONSE: {response_content[:200]}...")
444
+
445
+ except Exception as e:
446
+ print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
447
+ return {}
448
+
449
+ # Parse JSON response
450
+ import json
451
+ import re
452
+ try:
453
+ print(f"🔄 JSON PARSING: Attempting to parse LLM response...")
454
+
455
+ # Clean the response to extract JSON from markdown
456
+ response_text = response_content.strip()
457
+
458
+ # Remove markdown formatting if present
459
+ if "```json" in response_text:
460
+ # Extract JSON from markdown code block
461
+ start_marker = "```json"
462
+ end_marker = "```"
463
+ start_idx = response_text.find(start_marker)
464
+ if start_idx != -1:
465
+ start_idx += len(start_marker)
466
+ end_idx = response_text.find(end_marker, start_idx)
467
+ if end_idx != -1:
468
+ response_text = response_text[start_idx:end_idx].strip()
469
+
470
+ # Try to find JSON object in the response
471
+ json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
472
+ if json_match:
473
+ response_text = json_match.group(0)
474
+
475
+ print(f"🔄 JSON PARSING: Cleaned response: {response_text[:200]}...")
476
+
477
+ # Parse JSON
478
+ filters = json.loads(response_text)
479
+ print(f"✅ JSON PARSING SUCCESS: Parsed filters: {filters}")
480
+
481
+ # Validate filters
482
+ if not isinstance(filters, dict):
483
+ print(f"❌ JSON VALIDATION FAILED: Response is not a dictionary")
484
+ return {}
485
+
486
+ # Check if any filters were inferred
487
+ has_filters = any(filters.get(key) for key in ['sources', 'years', 'filenames'])
488
+ if not has_filters:
489
+ print(f"⚠️ QUERY DIFFICULT: LLM could not determine appropriate filters from query")
490
+ return {}
491
+
492
+ # print(f"✅ FILTER INFERENCE SUCCESS: Inferred filters: {filters}")
493
+ return filters
494
+
495
+ except json.JSONDecodeError as e:
496
+ print(f"❌ JSON PARSING FAILED: Invalid JSON format - {e}")
497
+ print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
498
+ return {}
499
+ except Exception as e:
500
+ print(f"❌ JSON PARSING FAILED: Unexpected error - {e}")
501
+ print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
502
+ return {}
503
+
504
+ except Exception as e:
505
+ print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
506
+ return {}
507
+
508
+
509
+ def _infer_filters_rule_based(
510
+ query: str,
511
+ available_metadata: dict
512
+ ) -> dict:
513
+ """
514
+ Rule-based fallback for filter inference with improved logic.
515
+
516
+ Args:
517
+ query: User query
518
+ available_metadata: Available metadata values in the vectorstore
519
+
520
+ Returns:
521
+ Dictionary of inferred filters
522
+ """
523
+ print(f" RULE-BASED ANALYSIS: Starting rule-based inference for query: '{query[:50]}...'")
524
+
525
+ inferred = {}
526
+ query_lower = query.lower()
527
+
528
+ # SEMANTIC SOURCE INFERENCE - Use semantic understanding
529
+ source_matches = []
530
+
531
+ # Define semantic mappings for better source inference
532
+ source_keywords = {
533
+ 'consolidated': ['consolidated', 'annual', 'oag', 'auditor general', 'government', 'financial statements', 'budget', 'expenditure', 'revenue'],
534
+ 'military': ['military', 'defence', 'defense', 'army', 'navy', 'air force', 'security', 'defense ministry'],
535
+ 'departmental': ['department', 'ministry', 'agency', 'authority', 'commission', 'board', 'directorate'],
536
+ 'thematic': ['thematic', 'sector', 'program', 'project', 'initiative', 'development', 'infrastructure']
537
+ }
538
+
539
+ for source in available_metadata.get('sources', []):
540
+ source_lower = source.lower()
541
+
542
+ # Direct keyword match
543
+ if source_lower in query_lower:
544
+ source_matches.append(source)
545
+ print(f"✅ DIRECT MATCH: Found direct keyword match for '{source}'")
546
+ else:
547
+ # Semantic keyword matching
548
+ if source_lower in source_keywords:
549
+ keywords = source_keywords[source_lower]
550
+ matches = sum(1 for keyword in keywords if keyword in query_lower)
551
+ if matches >= 2: # Require at least 2 keyword matches for semantic inference
552
+ source_matches.append(source)
553
+ print(f"✅ SEMANTIC MATCH: Found {matches} semantic keywords for '{source}': {[k for k in keywords if k in query_lower]}")
554
+
555
+ if source_matches:
556
+ # Use SHOULD (OR logic) for multiple sources
557
+ inferred['sources_should'] = source_matches
558
+ print(f"✅ SOURCE INFERENCE: Found {len(source_matches)} sources with OR logic: {source_matches}")
559
+ else:
560
+ print("❌ SOURCE INFERENCE: No source keywords found in query")
561
+
562
+ # Infer year filters - use SHOULD (OR logic) for multiple years
563
+ import re
564
+ year_matches = []
565
+ for year in available_metadata.get('years', []):
566
+ if year in query or f"'{year}" in query:
567
+ year_matches.append(year)
568
+
569
+ if year_matches:
570
+ # Use SHOULD (OR logic) for multiple years
571
+ inferred['years_should'] = year_matches
572
+ print(f"✅ YEAR INFERENCE: Found {len(year_matches)} years with OR logic: {year_matches}")
573
+ else:
574
+ print("❌ YEAR INFERENCE: No year references found in query")
575
+
576
+ # Only infer filename filters if no year filter was found (to avoid conflicts)
577
+ if not year_matches:
578
+ filename_matches = []
579
+ for filename in available_metadata.get('filenames', []):
580
+ # Only match if multiple words from filename appear in query
581
+ filename_words = filename.lower().split()
582
+ matches = sum(1 for word in filename_words if word in query_lower)
583
+ if matches >= 2: # High confidence threshold
584
+ filename_matches.append(filename)
585
+
586
+ if filename_matches:
587
+ # Use SHOULD (OR logic) for multiple filenames
588
+ inferred['filenames_should'] = filename_matches
589
+ print(f"✅ FILENAME INFERENCE: Found {len(filename_matches)} filenames with OR logic: {filename_matches}")
590
+ else:
591
+ print("❌ FILENAME INFERENCE: No high-confidence filename matches found")
592
+ else:
593
+ print("ℹ️ FILENAME INFERENCE: Skipped (year filter already applied to avoid conflicts)")
594
+
595
+ print(f" RULE-BASED RESULT: {inferred}")
596
+ return inferred
597
+
598
+
599
+ def _validate_inferred_filters(inferred_filters: dict) -> dict:
600
+ """
601
+ Validate and normalize inferred filters to ensure they're in the expected format.
602
+
603
+ Args:
604
+ inferred_filters: Raw inferred filters dictionary
605
+
606
+ Returns:
607
+ Validated and normalized filters dictionary
608
+ """
609
+ if not isinstance(inferred_filters, dict):
610
+ print(f"⚠️ FILTER VALIDATION: Inferred filters is not a dict: {type(inferred_filters)}")
611
+ return {}
612
+
613
+ validated = {}
614
+
615
+ # Normalize field names and validate values
616
+ for field_name in ['sources', 'sources_should', 'years', 'years_should', 'filenames', 'filenames_should']:
617
+ if field_name in inferred_filters and inferred_filters[field_name]:
618
+ value = inferred_filters[field_name]
619
+ if isinstance(value, list) and len(value) > 0:
620
+ # Remove any None or empty string values
621
+ clean_value = [v for v in value if v is not None and str(v).strip()]
622
+ if clean_value:
623
+ validated[field_name] = clean_value
624
+ print(f"✅ FILTER VALIDATION: {field_name} = {clean_value}")
625
+ elif isinstance(value, str) and value.strip():
626
+ validated[field_name] = [value.strip()]
627
+ print(f"✅ FILTER VALIDATION: {field_name} = [{value.strip()}]")
628
+
629
+ return validated
630
+
631
+
632
+ def _build_qdrant_filter(inferred_filters: dict) -> Tuple[Optional[rest.Filter], Dict[str, Any]]:
633
+ """
634
+ Build Qdrant filter from inferred filters.
635
+
636
+ Args:
637
+ inferred_filters: Dictionary with inferred filter values
638
+
639
+ Returns:
640
+ Tuple of (Qdrant Filter object or None, filter summary dict)
641
+ """
642
+ try:
643
+ from qdrant_client.http import models as rest
644
+
645
+ # Validate and normalize the inferred filters first
646
+ validated_filters = _validate_inferred_filters(inferred_filters)
647
+ if not validated_filters:
648
+ print(f"⚠️ NO VALID FILTERS: All filters were invalid or empty")
649
+ return None, {}
650
+
651
+ conditions = []
652
+ filter_summary = {}
653
+
654
+ # Handle sources (use OR logic for multiple values)
655
+ # Support both 'sources' and 'sources_should' field names
656
+ source_values = None
657
+ if 'sources' in validated_filters and validated_filters['sources']:
658
+ source_values = validated_filters['sources']
659
+ elif 'sources_should' in validated_filters and validated_filters['sources_should']:
660
+ source_values = validated_filters['sources_should']
661
+
662
+ if source_values and isinstance(source_values, list) and len(source_values) > 0:
663
+ if len(source_values) == 1:
664
+ conditions.append(rest.FieldCondition(
665
+ key="metadata.source",
666
+ match=rest.MatchValue(value=source_values[0])
667
+ ))
668
+ else:
669
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
670
+ conditions.append(rest.FieldCondition(
671
+ key="metadata.source",
672
+ match=rest.MatchAny(any=source_values)
673
+ ))
674
+ filter_summary['sources'] = f"SHOULD: {source_values}"
675
+
676
+ # Handle years (use OR logic for multiple values)
677
+ # Support both 'years' and 'years_should' field names
678
+ year_values = None
679
+ if 'years' in validated_filters and validated_filters['years']:
680
+ year_values = validated_filters['years']
681
+ elif 'years_should' in validated_filters and validated_filters['years_should']:
682
+ year_values = validated_filters['years_should']
683
+
684
+ if year_values and isinstance(year_values, list) and len(year_values) > 0:
685
+ if len(year_values) == 1:
686
+ conditions.append(rest.FieldCondition(
687
+ key="metadata.year",
688
+ match=rest.MatchValue(value=year_values[0])
689
+ ))
690
+ else:
691
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
692
+ conditions.append(rest.FieldCondition(
693
+ key="metadata.year",
694
+ match=rest.MatchAny(any=year_values)
695
+ ))
696
+ filter_summary['years'] = f"SHOULD: {year_values}"
697
+
698
+ # Handle filenames (use OR logic for multiple values)
699
+ # Support both 'filenames' and 'filenames_should' field names
700
+ filename_values = None
701
+ if 'filenames' in validated_filters and validated_filters['filenames']:
702
+ filename_values = validated_filters['filenames']
703
+ elif 'filenames_should' in validated_filters and validated_filters['filenames_should']:
704
+ filename_values = validated_filters['filenames_should']
705
+
706
+ if filename_values and isinstance(filename_values, list) and len(filename_values) > 0:
707
+ if len(filename_values) == 1:
708
+ conditions.append(rest.FieldCondition(
709
+ key="metadata.filename",
710
+ match=rest.MatchValue(value=filename_values[0])
711
+ ))
712
+ else:
713
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
714
+ conditions.append(rest.FieldCondition(
715
+ key="metadata.filename",
716
+ match=rest.MatchAny(any=filename_values)
717
+ ))
718
+ filter_summary['filenames'] = f"SHOULD: {filename_values}"
719
+
720
+ # Build final filter
721
+ if conditions:
722
+ # Always wrap conditions in a Filter object, even for single conditions
723
+ result_filter = rest.Filter(must=conditions)
724
+
725
+ # Print clean filter summary
726
+ print(f"✅ APPLIED FILTERS: {filter_summary}")
727
+ return result_filter, filter_summary
728
+ else:
729
+ print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
730
+ return None, {}
731
+
732
+ except Exception as e:
733
+ print(f"❌ FILTER BUILD ERROR: {str(e)}")
734
+ print(f"🔍 DEBUG: Original inferred filters keys: {list(inferred_filters.keys()) if isinstance(inferred_filters, dict) else 'Not a dict'}")
735
+ print(f"🔍 DEBUG: Original inferred filters content: {inferred_filters}")
736
+ print(f"🔍 DEBUG: Validated filters keys: {list(validated_filters.keys()) if isinstance(validated_filters, dict) else 'Not a dict'}")
737
+ print(f"🔍 DEBUG: Validated filters content: {validated_filters}")
738
+ # Return a safe fallback - no filter (search all documents)
739
+ return None, {}
740
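A small worked example of the mapping this function performs (both "years" and "years_should" spellings are accepted):

inferred = {"sources": ["Local Government"], "years_should": ["2022", "2023"]}
flt, summary = _build_qdrant_filter(inferred)
# flt.must -> [FieldCondition(metadata.source == "Local Government"),
#              FieldCondition(metadata.year IN ["2022", "2023"])]
# summary  -> {"sources": "SHOULD: [...]", "years": "SHOULD: [...]"}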
+
741
+
742
+ class MetadataCache:
743
+ """Cache for vectorstore metadata to avoid repeated queries."""
744
+
745
+ def __init__(self):
746
+ self._cache = None
747
+ self._last_updated = None
748
+ self._cache_ttl = 3600 # 1 hour TTL
749
+
750
+ def get_metadata(self, vectorstore) -> dict:
751
+ """
752
+ Get metadata from cache or load it if not available/expired.
753
+
754
+ Args:
755
+ vectorstore: QdrantVectorStore instance
756
+
757
+ Returns:
758
+ Dictionary of available metadata values
759
+ """
760
+ import time
761
+
762
+ # Check if cache is valid
763
+ if (self._cache is not None and
764
+ self._last_updated is not None and
765
+ time.time() - self._last_updated < self._cache_ttl):
766
+ print(f"✅ METADATA CACHE: Using cached metadata")
767
+ return self._cache
768
+
769
+ try:
770
+ print(f"🔄 METADATA CACHE: Loading metadata from vectorstore...")
771
+
772
+
779
+ # Get ALL documents to extract complete metadata
780
+ print(f"📄 Scanning entire corpus for complete metadata extraction...")
781
+
782
+ # Get collection info to determine total size
783
+ try:
784
+ collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
785
+ total_points = getattr(collection_info, 'points_count', 0)
786
+ print(f"📊 Total documents in corpus: {total_points}")
787
+ except Exception as e:
788
+ print(f"⚠️ Could not get collection size: {e}")
789
+ total_points = 0
790
+
791
+ # Extract unique metadata values from ALL documents
792
+ sources = set()
793
+ years = set()
794
+ filenames = set()
795
+
796
+ # Try to use scroll to get all documents in batches
797
+ batch_size = 1000 # Process in batches to avoid memory issues
798
+ offset = None
799
+ processed_count = 0
800
+ scroll_success = False
801
+
802
+ try:
803
+ while True:
804
+ # Scroll through all documents
805
+ scroll_result = vectorstore._client.scroll(
806
+ collection_name=vectorstore.collection_name,
807
+ limit=batch_size,
808
+                     offset=offset,
+                     with_payload=True,
+                     with_vectors=False  # We only need metadata
+                 )
+
+                 points = scroll_result[0]  # Get the points
+                 if not points:
+                     break  # No more documents
+
+                 # Process each document
+                 for i, point in enumerate(points):
+                     if hasattr(point, 'payload') and point.payload:
+                         payload = point.payload
+
+                         # Debug: log the structure of the first two documents
+                         if processed_count + i < 2:
+                             print(f"🔍 DEBUG Document {processed_count + i + 1} payload structure:")
+                             print(f"   Payload keys: {list(payload.keys()) if isinstance(payload, dict) else 'Not a dict'}")
+                             if isinstance(payload, dict) and 'metadata' in payload:
+                                 print(f"   Metadata keys: {list(payload['metadata'].keys()) if isinstance(payload['metadata'], dict) else 'Not a dict'}")
+                             elif isinstance(payload, dict):
+                                 print(f"   Top-level keys: {list(payload.keys())}")
+                             print(f"   Payload type: {type(payload)}")
+                             print(f"   Payload sample: {str(payload)[:200]}...")
+                             print()
+
+                         # Try different metadata structures
+                         found_metadata = False
+
+                         # Structure 1: payload['metadata']['source']
+                         if isinstance(payload, dict) and 'metadata' in payload:
+                             metadata = payload['metadata']
+                             if isinstance(metadata, dict):
+                                 if 'source' in metadata:
+                                     sources.add(metadata['source'])
+                                     found_metadata = True
+                                 if 'year' in metadata:
+                                     years.add(metadata['year'])
+                                     found_metadata = True
+                                 if 'filename' in metadata:
+                                     filenames.add(metadata['filename'])
+                                     found_metadata = True
+
+                         # Structure 2: payload['source'] (direct)
+                         if isinstance(payload, dict):
+                             if 'source' in payload:
+                                 sources.add(payload['source'])
+                                 found_metadata = True
+                             if 'year' in payload:
+                                 years.add(payload['year'])
+                                 found_metadata = True
+                             if 'filename' in payload:
+                                 filenames.add(payload['filename'])
+                                 found_metadata = True
+
+                         # Structure 3: any nested dict that might contain metadata
+                         if not found_metadata and isinstance(payload, dict):
+                             for key, value in payload.items():
+                                 if isinstance(value, dict):
+                                     if 'source' in value:
+                                         sources.add(value['source'])
+                                         found_metadata = True
+                                     if 'year' in value:
+                                         years.add(value['year'])
+                                         found_metadata = True
+                                     if 'filename' in value:
+                                         filenames.add(value['filename'])
+                                         found_metadata = True
+
+                 processed_count += len(points)
+                 progress_pct = (processed_count / total_points * 100) if total_points > 0 else 0
+                 print(f"📄 Processed {processed_count}/{total_points} documents ({progress_pct:.1f}%)... (sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
+
+                 # Update offset for the next batch
+                 offset = scroll_result[1]  # Next offset
+                 if offset is None:
+                     break  # No more documents
+
+             scroll_success = True
+             print(f"✅ Scroll method successful - processed {processed_count} documents")
+
+         except Exception as e:
+             print(f"❌ Scroll method failed: {e}")
+             print("🔄 Falling back to similarity search method...")
+
+             # Fallback: use similarity search with multiple queries for broader coverage
+             fallback_queries = [
+                 "",  # Empty query
+                 "audit", "report", "government", "ministry", "department",
+                 "local", "consolidated", "annual", "financial", "budget",
+                 "2020", "2021", "2022", "2023", "2024"  # Year queries
+             ]
+
+             processed_count = 0
+             for query in fallback_queries:
+                 try:
+                     # Get documents for this query
+                     docs = vectorstore.similarity_search(query, k=1000)  # Get more per query
+
+                     for j, doc in enumerate(docs):
+                         if hasattr(doc, 'metadata') and doc.metadata:
+                             # Debug: log the structure of the first few documents per query
+                             if processed_count + j < 3:
+                                 print(f"🔍 DEBUG Fallback Document {processed_count + j + 1} (query: '{query}') metadata structure:")
+                                 print(f"   Metadata keys: {list(doc.metadata.keys()) if isinstance(doc.metadata, dict) else 'Not a dict'}")
+                                 print(f"   Metadata type: {type(doc.metadata)}")
+                                 print(f"   Metadata sample: {str(doc.metadata)[:200]}...")
+                                 print()
+
+                             if 'source' in doc.metadata:
+                                 sources.add(doc.metadata['source'])
+                             if 'year' in doc.metadata:
+                                 years.add(doc.metadata['year'])
+                             if 'filename' in doc.metadata:
+                                 filenames.add(doc.metadata['filename'])
+
+                     processed_count += len(docs)
+                     print(f"📄 Fallback query '{query}': {len(docs)} docs (total: {processed_count}, sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
+
+                 except Exception as query_error:
+                     print(f"⚠️ Fallback query '{query}' failed: {query_error}")
+                     continue
+
+             print(f"✅ Fallback method completed - processed {processed_count} documents")
+
+         print(f"✅ Completed scanning {processed_count} documents from the corpus")
+
+         # Convert to sorted lists
+         metadata = {
+             'sources': sorted(sources),
+             'years': sorted(years),
+             'filenames': sorted(filenames)
+         }
+
+         # Cache the results
+         self._cache = metadata
+         self._last_updated = time.time()
+
+         print(f"✅ Metadata extracted from the corpus: {len(sources)} sources, {len(years)} years, {len(filenames)} files")
+
+         # Debug: show what was actually found
+         if sources:
+             print(f"📁 Sources found: {sorted(sources)}")
+         else:
+             print("❌ No sources found - check metadata structure")
+
+         if years:
+             print(f"📅 Years found: {sorted(years)}")
+         else:
+             print("❌ No years found - check metadata structure")
+
+         if filenames:
+             print(f"📄 Filenames found: {sorted(filenames)[:10]}{'...' if len(filenames) > 10 else ''}")
+         else:
+             print("❌ No filenames found - check metadata structure")
+
+         return metadata
+
+     except Exception as e:
+         print(f"❌ Error extracting metadata: {e}")
+         return {'sources': [], 'years': [], 'filenames': []}
+
+
+ # Global metadata cache
+ _metadata_cache = MetadataCache()
+
+
+ def get_available_metadata(vectorstore) -> dict:
+     """Get available metadata values from the vectorstore efficiently."""
+     return _metadata_cache.get_metadata(vectorstore)
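A quick usage sketch for the cached lookup (illustrative, not part of the commit: `vectorstore` is assumed to be an already-connected QdrantVectorStore, and the import path depends on where this module lives in the package):

    meta = get_available_metadata(vectorstore)   # first call scans the collection
    print(meta['sources'], meta['years'], meta['filenames'][:5])
    meta = get_available_metadata(vectorstore)   # later calls are served from the cache

The returned dict always has the keys 'sources', 'years', and 'filenames', each a sorted list, even when extraction fails.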
src/retrieval/hybrid.py ADDED
@@ -0,0 +1,479 @@
+ """Hybrid search implementation combining vector and sparse retrieval."""
+
+ import os
+ import pickle
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ from langchain.docstore.document import Document
+ from langchain_community.retrievers import BM25Retriever
+ from langchain_qdrant import QdrantVectorStore
+
+ from .filter import create_filter
+
+
+ class HybridRetriever:
+     """
+     Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.
+     Supports configurable search modes: vector_only, sparse_only, or hybrid.
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         """
+         Initialize the hybrid retriever.
+
+         Args:
+             config: Configuration dictionary with hybrid search settings
+         """
+         self.config = config
+         self.bm25_retriever = None
+         self.documents = []
+         self._bm25_cache_file = None
+
+     def _get_bm25_cache_path(self) -> str:
+         """Get the path of the BM25 cache file."""
+         cache_dir = Path("cache/bm25")
+         cache_dir.mkdir(parents=True, exist_ok=True)
+         return str(cache_dir / "bm25_retriever.pkl")
+
+     def initialize_bm25(self, documents: List[Document], force_rebuild: bool = False) -> None:
+         """
+         Initialize the BM25 retriever with documents.
+
+         Args:
+             documents: List of Document objects to index
+             force_rebuild: Whether to force rebuilding the BM25 index
+         """
+         self.documents = documents
+         self._bm25_cache_file = self._get_bm25_cache_path()
+
+         # Try to load a cached BM25 retriever
+         if not force_rebuild and os.path.exists(self._bm25_cache_file):
+             try:
+                 print("Loading cached BM25 retriever...")
+                 with open(self._bm25_cache_file, 'rb') as f:
+                     self.bm25_retriever = pickle.load(f)
+                 print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
+                 return
+             except Exception as e:
+                 print(f"⚠️ Failed to load cached BM25 retriever: {e}")
+                 print("Building new BM25 index...")
+
+         # Build a new BM25 retriever
+         print("Building BM25 index...")
+         try:
+             # Use langchain's BM25Retriever
+             self.bm25_retriever = BM25Retriever.from_documents(documents)
+
+             # Configure BM25 parameters
+             bm25_config = self.config.get("bm25", {})
+             self.bm25_retriever.k = bm25_config.get("top_k", 20)
+
+             # Cache the BM25 retriever
+             with open(self._bm25_cache_file, 'wb') as f:
+                 pickle.dump(self.bm25_retriever, f)
+             print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")
+
+         except Exception as e:
+             print(f"❌ Failed to build BM25 retriever: {e}")
+             print("BM25 search will be disabled")
+             self.bm25_retriever = None
+
+     def _filter_documents_by_metadata(
+         self,
+         documents: List[Document],
+         reports: Optional[List[str]] = None,
+         sources: Optional[str] = None,
+         subtype: Optional[List[str]] = None,
+         year: Optional[List[str]] = None
+     ) -> List[Document]:
+         """
+         Filter documents by metadata criteria.
+
+         Args:
+             documents: List of documents to filter
+             reports: List of specific report filenames
+             sources: Source category
+             subtype: List of subtypes
+             year: List of years
+
+         Returns:
+             Filtered list of documents
+         """
+         if not any([reports, sources, subtype, year]):
+             return documents
+
+         filtered_docs = []
+         for doc in documents:
+             metadata = doc.metadata
+
+             # Filter by reports
+             if reports:
+                 filename = metadata.get('filename', '')
+                 if not any(report in filename for report in reports):
+                     continue
+
+             # Filter by source
+             if sources and metadata.get('source', '') != sources:
+                 continue
+
+             # Filter by subtype
+             if subtype and metadata.get('subtype', '') not in subtype:
+                 continue
+
+             # Filter by year
+             if year and str(metadata.get('year', '')) not in year:
+                 continue
+
+             filtered_docs.append(doc)
+
+         return filtered_docs
+
+     def _bm25_search(
+         self,
+         query: str,
+         k: int = 20,
+         reports: Optional[List[str]] = None,
+         sources: Optional[str] = None,
+         subtype: Optional[List[str]] = None,
+         year: Optional[List[str]] = None
+     ) -> List[Tuple[Document, float]]:
+         """
+         Perform BM25 sparse search.
+
+         Args:
+             query: Search query
+             k: Number of documents to retrieve
+             reports: List of specific report filenames
+             sources: Source category
+             subtype: List of subtypes
+             year: List of years
+
+         Returns:
+             List of (Document, score) tuples
+         """
+         if not self.bm25_retriever:
+             print("⚠️ BM25 retriever not available")
+             return []
+
+         try:
+             # Get BM25 results
+             self.bm25_retriever.k = k
+             bm25_docs = self.bm25_retriever.invoke(query)
+
+             # Apply metadata filtering
+             if any([reports, sources, subtype, year]):
+                 bm25_docs = self._filter_documents_by_metadata(
+                     bm25_docs, reports, sources, subtype, year
+                 )
+
+             # BM25Retriever doesn't expose scores directly, so assign placeholder
+             # scores that decrease with rank, normalized to [0, 1] for consistency
+             # with vector search. A production system should surface the actual
+             # BM25 scores instead.
+             results = []
+             for i, doc in enumerate(bm25_docs):
+                 score = max(0.1, 1.0 - (i / max(len(bm25_docs), 1)))
+                 results.append((doc, score))
+
+             return results
+
+         except Exception as e:
+             print(f"❌ BM25 search failed: {e}")
+             return []
+
+     def _vector_search(
+         self,
+         vectorstore: QdrantVectorStore,
+         query: str,
+         k: int = 20,
+         reports: Optional[List[str]] = None,
+         sources: Optional[str] = None,
+         subtype: Optional[List[str]] = None,
+         year: Optional[List[str]] = None
+     ) -> List[Tuple[Document, float]]:
+         """
+         Perform vector similarity search.
+
+         Args:
+             vectorstore: QdrantVectorStore instance
+             query: Search query
+             k: Number of documents to retrieve
+             reports: List of specific report filenames
+             sources: Source category
+             subtype: List of subtypes
+             year: List of years
+
+         Returns:
+             List of (Document, score) tuples
+         """
+         try:
+             # Create metadata filter
+             filter_obj = create_filter(
+                 reports=reports,
+                 sources=sources,
+                 subtype=subtype,
+                 year=year
+             )
+
+             # Perform vector search
+             if filter_obj:
+                 return vectorstore.similarity_search_with_score(
+                     query, k=k, filter=filter_obj
+                 )
+             return vectorstore.similarity_search_with_score(query, k=k)
+
+         except Exception as e:
+             print(f"❌ Vector search failed: {e}")
+             return []
+
+     def _normalize_scores(self, results: List[Tuple[Document, float]], method: str = "min_max") -> List[Tuple[Document, float]]:
+         """
+         Normalize scores to the [0, 1] range.
+
+         Args:
+             results: List of (Document, score) tuples
+             method: Normalization method ('min_max' or 'z_score')
+
+         Returns:
+             List of (Document, normalized_score) tuples
+         """
+         if not results:
+             return results
+
+         scores = [score for _, score in results]
+
+         if method == "min_max":
+             min_score = min(scores)
+             max_score = max(scores)
+             if max_score == min_score:
+                 normalized_results = [(doc, 1.0) for doc, _ in results]
+             else:
+                 normalized_results = [
+                     (doc, (score - min_score) / (max_score - min_score))
+                     for doc, score in results
+                 ]
+         elif method == "z_score":
+             mean_score = np.mean(scores)
+             std_score = np.std(scores)
+             if std_score == 0:
+                 normalized_results = [(doc, 1.0) for doc, _ in results]
+             else:
+                 # Clamp negative z-scores to 0 so scores stay non-negative
+                 normalized_results = [
+                     (doc, max(0, (score - mean_score) / std_score))
+                     for doc, score in results
+                 ]
+         else:
+             normalized_results = results
+
+         return normalized_results
+
+     def _combine_results(
+         self,
+         vector_results: List[Tuple[Document, float]],
+         bm25_results: List[Tuple[Document, float]],
+         alpha: float = 0.5
+     ) -> List[Tuple[Document, float]]:
+         """
+         Combine vector and BM25 results with weighted scoring.
+
+         Args:
+             vector_results: Vector search results
+             bm25_results: BM25 search results
+             alpha: Weight for vector scores (1 - alpha for BM25 scores)
+
+         Returns:
+             Combined and ranked results
+         """
+         # Normalize scores
+         vector_results = self._normalize_scores(vector_results)
+         bm25_results = self._normalize_scores(bm25_results)
+
+         # Key documents by content: the two retrievers return distinct Document
+         # objects for the same underlying chunk, so identity-based keys (id())
+         # would never overlap and the fusion would degenerate to concatenation.
+         vector_docs = {doc.page_content: (doc, score) for doc, score in vector_results}
+         bm25_docs = {doc.page_content: (doc, score) for doc, score in bm25_results}
+
+         # Combine scores
+         combined_scores = {}
+         all_doc_keys = set(vector_docs) | set(bm25_docs)
+
+         for doc_key in all_doc_keys:
+             vector_score = vector_docs.get(doc_key, (None, 0.0))[1]
+             bm25_score = bm25_docs.get(doc_key, (None, 0.0))[1]
+
+             # Weighted combination
+             combined_score = alpha * vector_score + (1 - alpha) * bm25_score
+
+             # Get the document object from whichever result set has it
+             doc = vector_docs.get(doc_key, bm25_docs.get(doc_key))[0]
+             combined_scores[doc_key] = (doc, combined_score)
+
+         # Sort by combined score (descending)
+         return sorted(
+             combined_scores.values(),
+             key=lambda x: x[1],
+             reverse=True
+         )
+
+     def retrieve(
+         self,
+         vectorstore: QdrantVectorStore,
+         query: str,
+         mode: str = "hybrid",
+         reports: Optional[List[str]] = None,
+         sources: Optional[str] = None,
+         subtype: Optional[List[str]] = None,
+         year: Optional[List[str]] = None,
+         alpha: float = 0.5,
+         k: Optional[int] = None
+     ) -> List[Document]:
+         """
+         Retrieve documents using the specified search mode.
+
+         Args:
+             vectorstore: QdrantVectorStore instance
+             query: Search query
+             mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
+             reports: List of specific report filenames
+             sources: Source category
+             subtype: List of subtypes
+             year: List of years
+             alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
+             k: Number of documents to retrieve
+
+         Returns:
+             List of relevant Document objects
+         """
+         # Delegate to retrieve_with_scores and drop the scores; the two code
+         # paths are otherwise identical.
+         results = self.retrieve_with_scores(
+             vectorstore, query, mode, reports, sources, subtype, year, alpha, k
+         )
+         return [doc for doc, _ in results]
+
+     def retrieve_with_scores(
+         self,
+         vectorstore: QdrantVectorStore,
+         query: str,
+         mode: str = "hybrid",
+         reports: Optional[List[str]] = None,
+         sources: Optional[str] = None,
+         subtype: Optional[List[str]] = None,
+         year: Optional[List[str]] = None,
+         alpha: float = 0.5,
+         k: Optional[int] = None
+     ) -> List[Tuple[Document, float]]:
+         """
+         Retrieve documents with scores using the specified search mode.
+
+         Args:
+             vectorstore: QdrantVectorStore instance
+             query: Search query
+             mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
+             reports: List of specific report filenames
+             sources: Source category
+             subtype: List of subtypes
+             year: List of years
+             alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
+             k: Number of documents to retrieve
+
+         Returns:
+             List of (Document, score) tuples
+         """
+         if k is None:
+             k = self.config.get("retriever", {}).get("top_k", 20)
+
+         if mode == "vector_only":
+             # Vector search only
+             results = self._vector_search(
+                 vectorstore, query, k, reports, sources, subtype, year
+             )
+
+         elif mode == "sparse_only":
+             # BM25 search only
+             results = self._bm25_search(
+                 query, k, reports, sources, subtype, year
+             )
+
+         elif mode == "hybrid":
+             # Hybrid search - retrieve extra candidates from each method so
+             # the fusion has more overlap to work with
+             retrieval_k = min(k * 2, 50)
+
+             vector_results = self._vector_search(
+                 vectorstore, query, retrieval_k, reports, sources, subtype, year
+             )
+             bm25_results = self._bm25_search(
+                 query, retrieval_k, reports, sources, subtype, year
+             )
+
+             results = self._combine_results(vector_results, bm25_results, alpha)
+
+         else:
+             raise ValueError(f"Unknown search mode: {mode}")
+
+         # Limit to the top k results
+         return results[:k]
+
+
+ def get_available_search_modes() -> List[str]:
+     """Get the list of available search modes."""
+     return ["vector_only", "sparse_only", "hybrid"]
+
+
+ def get_search_mode_description() -> Dict[str, str]:
+     """Get a description of each search mode."""
+     return {
+         "vector_only": "Semantic search using dense embeddings - good for conceptual matching",
+         "sparse_only": "Keyword search using BM25 - good for exact term matching",
+         "hybrid": "Combined semantic and keyword search - balanced approach"
+     }
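A short usage sketch (illustrative; `config`, `documents`, and the connected `vectorstore` are assumed to exist already and are placeholders, not part of the commit):

    retriever = HybridRetriever(config)
    retriever.initialize_bm25(documents)   # builds the BM25 index or loads the cached one
    docs = retriever.retrieve(
        vectorstore,
        "irregularities in procurement",
        mode="hybrid",
        year=["2022"],
        alpha=0.7,   # weight vector scores over BM25
        k=10,
    )

With alpha=0.7, a chunk whose normalized vector score is 0.8 and whose BM25 placeholder score is 0.5 fuses to 0.7 * 0.8 + 0.3 * 0.5 = 0.71; a chunk found by only one method keeps just its weighted share.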
src/vectorstore.py ADDED
@@ -0,0 +1,266 @@
+ """Vector store management and operations."""
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from langchain.docstore.document import Document
+ from langchain_core.embeddings import Embeddings
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_qdrant import QdrantVectorStore
+ from sentence_transformers import SentenceTransformer
+
+
+ class MatryoshkaEmbeddings(Embeddings):
+     """Custom embeddings class that supports Matryoshka dimension truncation."""
+
+     def __init__(self, model_name: str, truncate_dim: Optional[int] = None, **kwargs):
+         """
+         Initialize Matryoshka embeddings.
+
+         Args:
+             model_name: Name of the embedding model
+             truncate_dim: Dimension to truncate to (for Matryoshka models)
+             **kwargs: Extra arguments forwarded to HuggingFaceEmbeddings
+                 (ignored for Matryoshka models)
+         """
+         self.model_name = model_name
+         self.truncate_dim = truncate_dim
+
+         if truncate_dim and "matryoshka" in model_name.lower():
+             # Use SentenceTransformer directly for Matryoshka models
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
+             print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
+         else:
+             # Use standard HuggingFaceEmbeddings
+             self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """Embed documents."""
+         if self.truncate_dim and "matryoshka" in self.model_name.lower():
+             embeddings = self.model.encode(texts, normalize_embeddings=True)
+             return embeddings.tolist()
+         return self.model.embed_documents(texts)
+
+     def embed_query(self, text: str) -> List[float]:
+         """Embed a query."""
+         if self.truncate_dim and "matryoshka" in self.model_name.lower():
+             embedding = self.model.encode([text], normalize_embeddings=True)
+             return embedding[0].tolist()
+         return self.model.embed_query(text)
+
+
+ class VectorStoreManager:
+     """Manages vector store operations and connections."""
+
+     def __init__(self, config: Dict[str, Any]):
+         """
+         Initialize the vector store manager.
+
+         Args:
+             config: Configuration dictionary
+         """
+         self.config = config
+         self.embeddings = self._create_embeddings()
+         self.vectorstore = None
+
+         # Metadata fields that need payload indexes for filtering
+         self.metadata_fields = [
+             ("metadata.year", "keyword"),
+             ("metadata.source", "keyword"),
+             ("metadata.filename", "keyword"),
+             # Add more metadata fields as needed
+         ]
+
+     def _create_embeddings(self) -> Embeddings:
+         """Create the embeddings model from configuration."""
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         model_name = self.config["retriever"]["model"]
+         normalize = self.config["retriever"]["normalize"]
+
+         model_kwargs = {"device": device}
+         encode_kwargs = {
+             "normalize_embeddings": normalize,
+             "batch_size": 100,
+         }
+
+         # For Matryoshka models, check whether we need to truncate dimensions
+         if "matryoshka" in model_name.lower():
+             collection_name = self.config.get("qdrant", {}).get("collection_name", "")
+
+             if "modernbert-embed-base-akryl-matryoshka" in collection_name:
+                 # This collection expects 768-dimensional vectors
+                 return MatryoshkaEmbeddings(
+                     model_name=model_name,
+                     truncate_dim=768,
+                     model_kwargs=model_kwargs,
+                     encode_kwargs=encode_kwargs,
+                     show_progress=True,
+                 )
+
+         # Use standard HuggingFaceEmbeddings for non-Matryoshka models
+         return HuggingFaceEmbeddings(
+             model_name=model_name,
+             model_kwargs=model_kwargs,
+             encode_kwargs=encode_kwargs,
+             show_progress=True,
+         )
+
+     def ensure_metadata_indexes(self) -> None:
+         """
+         Create payload indexes for all required metadata fields.
+         This ensures filtering works properly, especially in Qdrant Cloud.
+         """
+         if not self.vectorstore:
+             return
+
+         collection_name = self.config["qdrant"]["collection_name"]
+
+         for field_name, field_type in self.metadata_fields:
+             try:
+                 self.vectorstore.client.create_payload_index(
+                     collection_name=collection_name,
+                     field_name=field_name,
+                     field_schema=field_type  # current qdrant-client name for this parameter
+                 )
+                 print(f"Created payload index for {field_name} ({field_type})")
+             except Exception as e:
+                 # Index might already exist or another error occurred - log but continue
+                 print(f"Index creation for {field_name} ({field_type}): {e}")
+
+     def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
+         """
+         Connect to an existing Qdrant collection.
+
+         Args:
+             force_recreate: If True, recreate the collection on a dimension mismatch
+
+         Returns:
+             QdrantVectorStore instance
+         """
+         qdrant_config = self.config["qdrant"]
+
+         kwargs_qdrant = {
+             "url": qdrant_config["url"],
+             "collection_name": qdrant_config["collection_name"],
+             "prefer_grpc": qdrant_config.get("prefer_grpc", True),
+             "api_key": qdrant_config.get("api_key", None),
+         }
+
+         if force_recreate:
+             kwargs_qdrant["force_recreate"] = True
+
+         self.vectorstore = QdrantVectorStore.from_existing_collection(
+             embedding=self.embeddings,
+             **kwargs_qdrant
+         )
+
+         # Ensure payload indexes exist for metadata filtering
+         self.ensure_metadata_indexes()
+
+         return self.vectorstore
+
+     def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
+         """
+         Create a new Qdrant collection from documents.
+
+         Args:
+             documents: List of Document objects
+
+         Returns:
+             QdrantVectorStore instance
+         """
+         qdrant_config = self.config["qdrant"]
+
+         kwargs_qdrant = {
+             "url": qdrant_config["url"],
+             "collection_name": qdrant_config["collection_name"],
+             "prefer_grpc": qdrant_config.get("prefer_grpc", True),
+             "api_key": qdrant_config.get("api_key", None),
+         }
+
+         self.vectorstore = QdrantVectorStore.from_documents(
+             documents=documents,
+             embedding=self.embeddings,
+             **kwargs_qdrant
+         )
+
+         # Ensure payload indexes exist for metadata filtering
+         self.ensure_metadata_indexes()
+
+         return self.vectorstore
+
+     def delete_collection(self) -> None:
+         """Delete the current Qdrant collection."""
+         collection_name = self.config["qdrant"].get("collection_name")
+         self.vectorstore.client.delete_collection(
+             collection_name=collection_name
+         )
+
+     def get_vectorstore(self) -> Optional[QdrantVectorStore]:
+         """Get the current vectorstore instance."""
+         return self.vectorstore
+
+
+ def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
+     """
+     Get a local Qdrant vector store (legacy function for compatibility).
+
+     Args:
+         config: Configuration dictionary
+
+     Returns:
+         QdrantVectorStore instance
+     """
+     manager = VectorStoreManager(config)
+     return manager.connect_to_existing()
+
+
+ def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
+     """
+     Create a new vector store from documents.
+
+     Args:
+         config: Configuration dictionary
+         documents: List of Document objects
+
+     Returns:
+         QdrantVectorStore instance
+     """
+     manager = VectorStoreManager(config)
+     return manager.create_from_documents(documents)
+
+
+ def get_embeddings_model(config: Dict[str, Any]) -> Embeddings:
+     """
+     Create the embeddings model from configuration (legacy function).
+
+     Args:
+         config: Configuration dictionary
+
+     Returns:
+         Embeddings instance (HuggingFaceEmbeddings or MatryoshkaEmbeddings)
+     """
+     manager = VectorStoreManager(config)
+     return manager.embeddings
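A minimal end-to-end sketch (illustrative; assumes a settings dict with `retriever` and `qdrant` sections like the ones the loader reads - the values shown are placeholders, not part of the commit):

    config = {
        "retriever": {"model": "BAAI/bge-m3", "normalize": True},
        "qdrant": {"url": "http://localhost:6333", "collection_name": "docling"},
    }
    manager = VectorStoreManager(config)
    vectorstore = manager.connect_to_existing()   # reuse an existing collection
    docs = vectorstore.similarity_search("audit findings", k=5)

Choosing a collection from src/config/collections.json whose embedding model matches `retriever.model` matters here: connecting with a different model, or a mismatched Matryoshka truncation, will fail or return meaningless neighbors because the vector dimensions and spaces would not line up.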