# audit_assistant / src / pipeline.py
# akryldigital's picture
# fix langchain_core.documents
# f30b993 verified
"""Main pipeline orchestrator for the Audit QA system."""
import os
import time
import traceback
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
try:
    from langchain.docstore.document import Document
except ImportError:
    # The legacy import path is unavailable; use the langchain_core location.
    # Only ImportError is caught so KeyboardInterrupt/SystemExit are not swallowed.
    from langchain_core.documents import Document
from .logging import log_error
from .llm.loader import chunks_to_documents
from .vectorstore import VectorStoreManager
from .reporting.service import ReportService
from .retrieval.context import ContextRetriever
from .llm.adapters import LLMRegistry, get_llm_client
from .llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
from .config.loader import load_config, get_embedding_model_for_collection
from .retrieval.filter import get_available_metadata, infer_filters_from_query
try:
from langchain_openai import ChatOpenAI
LANGCHAIN_OPENAI_AVAILABLE = True
except ImportError:
LANGCHAIN_OPENAI_AVAILABLE = False
@dataclass
class PipelineResult:
    """Value object describing the outcome of one pipeline execution."""

    # Text returned to the caller (an LLM answer or an "Error: ..." message).
    answer: str
    # Context documents the answer was built from.
    sources: List[Document]
    # Wall-clock duration in seconds.
    execution_time: float
    # Free-form details about the run (filters, reranking, errors, ...).
    metadata: Dict[str, Any]
    # Query text; optional so that results can be built without one.
    query: str = ""

    def __post_init__(self):
        """Replace a missing/empty query with a readable placeholder."""
        self.query = self.query or "Unknown query"
class PipelineManager:
"""Main pipeline manager for the RAG system."""
def __init__(self, config: dict = None):
"""
Initialize the pipeline manager.
"""
self.chunks = None
self.llm_client = None
self.config = config or {}
self.report_service = None
self.vectorstore_manager = None
self.context_retriever = None # Initialize as None
# Initialize components
self._initialize_components()
def update_config(self, new_config: dict):
"""
Update the pipeline configuration.
This is useful for experiments that need different settings.
"""
if not isinstance(new_config, dict):
return
# Deep merge the new config with existing config
def deep_merge(base_dict, update_dict):
for key, value in update_dict.items():
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
deep_merge(base_dict[key], value)
else:
base_dict[key] = value
deep_merge(self.config, new_config)
# Auto-infer embedding model from collection name if not "docling"
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
if collection_name != 'docling':
inferred_model = get_embedding_model_for_collection(collection_name)
if inferred_model:
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
if 'retriever' not in self.config:
self.config['retriever'] = {}
self.config['retriever']['model'] = inferred_model
# Set default normalize parameter if not present
if 'normalize' not in self.config['retriever']:
self.config['retriever']['normalize'] = True
# Also update vectorstore config if it exists
if 'vectorstore' in self.config:
self.config['vectorstore']['embedding_model'] = inferred_model
print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
# Re-initialize vectorstore manager with updated config
self._reinitialize_vectorstore_manager()
def _reinitialize_vectorstore_manager(self):
    """
    Rebuild ``self.vectorstore_manager`` from the current configuration.

    Failures are reported on stdout and otherwise ignored, which leaves the
    previous manager (if any) in place.
    """
    try:
        rebuilt_manager = VectorStoreManager(self.config)
        self.vectorstore_manager = rebuilt_manager
        print("🔄 VectorStore manager re-initialized with updated config")
    except Exception as exc:
        print(f"❌ Error re-initializing vectorstore manager: {exc}")
def _get_reranker_model_name(self) -> str:
"""
Get the reranker model name from configuration.
Returns:
Reranker model name or default
"""
return (
self.config.get('retrieval', {}).get('reranker_model') or
self.config.get('ranker', {}).get('model') or
self.config.get('reranker_model') or
'BAAI/bge-reranker-v2-m3'
)
def _initialize_components(self) -> None:
    """
    Initialize pipeline components.

    Steps, in order: load the config when none was supplied, validate it,
    optionally override the embedding model based on the Qdrant collection
    name, build the VectorStoreManager, obtain an LLM client through a chain
    of fallbacks, set the system prompt, and create the ReportService.

    LLM client and report service failures are tolerated (the attribute is
    left as None); config and vector store failures are not.

    Raises:
        ValueError: If the config is not a dict or has no 'retriever.model'.
        Exception: Any other uncaught error (notably from VectorStoreManager
            construction) is printed with a traceback and re-raised.
    """
    try:
        # Load config if not provided
        if not self.config:
            self.config = load_config()
        # Validate config structure
        if not isinstance(self.config, dict):
            raise ValueError(f"Config must be a dict, got {type(self.config)}")
        # Ensure retriever config exists
        if 'retriever' not in self.config:
            self.config['retriever'] = {}
        # NOTE(review): this check runs before the auto-inference below, so a
        # config without 'retriever.model' is rejected even when the model
        # could have been inferred from the collection name -- confirm intent.
        if 'model' not in self.config['retriever']:
            raise ValueError("Config must have 'retriever.model' specified")
        # Auto-infer embedding model from collection name if not "docling"
        collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
        if collection_name != 'docling':
            inferred_model = get_embedding_model_for_collection(collection_name)
            if inferred_model:
                print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
                # Redundant with the check above, kept as a safety net.
                if 'retriever' not in self.config:
                    self.config['retriever'] = {}
                # The inferred model overrides whatever the config specified.
                self.config['retriever']['model'] = inferred_model
                # Set default normalize parameter if not present
                if 'normalize' not in self.config['retriever']:
                    self.config['retriever']['normalize'] = True
                # Also update vectorstore config if it exists
                if 'vectorstore' in self.config:
                    self.config['vectorstore']['embedding_model'] = inferred_model
        # Initialize vectorstore manager - this might fail if model loading fails
        try:
            self.vectorstore_manager = VectorStoreManager(self.config)
            print("✅ VectorStoreManager initialized successfully")
        except Exception as vs_error:
            print(f"❌ Error initializing VectorStoreManager: {vs_error}")
            traceback.print_exc()
            self.vectorstore_manager = None
            raise  # Re-raise to be caught by outer try-except
        self.llm_manager = LLMRegistry()
        # Try to get LLM client using the correct method
        self.llm_client = None
        # Fallback chain: registry adapter -> get_llm_client() -> direct
        # ChatOpenAI. Each step is attempted only when the previous one raised.
        try:
            # Try using get_adapter method (most likely correct)
            self.llm_client = self.llm_manager.get_adapter("openai")
            print("✅ LLM CLIENT: Initialized using get_adapter method")
        except Exception as e:
            try:
                # Try direct instantiation with config
                self.llm_client = get_llm_client("openai", self.config)
                print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
            except Exception as e2:
                print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
                # Try to create a simple LLM client directly
                try:
                    if LANGCHAIN_OPENAI_AVAILABLE:
                        # Either key is accepted; OPENAI_API_KEY takes precedence.
                        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
                        if api_key:
                            self.llm_client = ChatOpenAI(
                                model="gpt-3.5-turbo",
                                api_key=api_key,
                                temperature=0.1,
                                max_tokens=1000
                            )
                            print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
                        else:
                            print("❌ LLM CLIENT: No API key available")
                    else:
                        print("❌ LLM CLIENT: langchain-openai not available")
                except Exception as e3:
                    print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
                    self.llm_client = None
        # Load system prompt
        self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
        # Initialize report service (optional: failure only produces a warning)
        try:
            self.report_service = ReportService()
        except Exception as e:
            print(f"Warning: Could not initialize report service: {e}")
            self.report_service = None
    except Exception as e:
        print(f"❌ Error initializing components: {e}")
        traceback.print_exc()
        # Don't set vectorstore_manager to None if it was already set;
        # this only guarantees the attribute exists.
        if not hasattr(self, 'vectorstore_manager') or self.vectorstore_manager is None:
            self.vectorstore_manager = None
        raise  # Re-raise to allow caller to handle
def test_retrieval(
    self,
    query: str,
    reports: List[str] = None,
    sources: str = None,
    subtype: List[str] = None,
    k: int = None,
    search_mode: str = None,
    search_alpha: float = None,
    use_reranking: bool = True
) -> Dict[str, Any]:
    """
    Test retrieval only without LLM inference.

    Args:
        query: User query
        reports: List of specific report filenames
        sources: Source category
        subtype: List of subtypes
        k: Number of documents to retrieve
        search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid');
            defaults to config ``hybrid.default_mode`` or 'vector_only'
        search_alpha: Weight for vector scores in hybrid mode; defaults to
            config ``hybrid.default_alpha`` or 0.5
        use_reranking: Accepted for interface compatibility; not used by
            this method

    Returns:
        Dictionary with ``results`` (rank, content, metadata, score),
        ``num_results``, ``execution_time``, ``search_mode``,
        ``search_alpha`` and ``query``. On failure ``results`` is empty and
        an ``error`` message is included; no exception is raised.
    """
    start_time = time.time()
    try:
        # Set default search parameters if not provided
        hybrid_config = self.config.get("hybrid", {})
        if search_mode is None:
            search_mode = hybrid_config.get("default_mode", "vector_only")
        if search_alpha is None:
            search_alpha = hybrid_config.get("default_alpha", 0.5)

        # Get vector store (a missing manager is reported like a missing store)
        manager = self.vectorstore_manager
        vectorstore = manager.get_vectorstore() if manager is not None else None
        if not vectorstore:
            raise ValueError(
                "Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
            )

        # The retriever is otherwise only created inside run(); build it here
        # too so this method also works when it is called first.
        if getattr(self, "context_retriever", None) is None:
            self.context_retriever = ContextRetriever(vectorstore, self.config)

        # Retrieve context with scores for test retrieval
        context_docs_with_scores = self.context_retriever.retrieve_with_scores(
            vectorstore=vectorstore,
            query=query,
            reports=reports,
            sources=sources,
            subtype=subtype,
            k=k,
            search_mode=search_mode,
            alpha=search_alpha,
        )
        execution_time = time.time() - start_time

        # Format results in a single pass (safe for one-shot iterables);
        # full content is returned without truncation.
        results = [
            {
                "rank": rank,
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": score if score is not None else 0.0,
            }
            for rank, (doc, score) in enumerate(context_docs_with_scores, start=1)
        ]
        return {
            "results": results,
            "num_results": len(results),
            "execution_time": execution_time,
            "search_mode": search_mode,
            "search_alpha": search_alpha,
            "query": query
        }
    except Exception as e:
        print(f"❌ Error during retrieval test: {e}")
        log_error(e, {"component": "retrieval_test", "query": query})
        return {
            "results": [],
            "num_results": 0,
            "execution_time": time.time() - start_time,
            "error": str(e),
            "search_mode": search_mode or "unknown",
            # An explicit alpha of 0.0 is valid and must be reported as-is.
            "search_alpha": search_alpha if search_alpha is not None else 0.5,
            "query": query
        }
def connect_vectorstore(self, force_recreate: bool = False) -> bool:
"""
Connect to existing vector store.
Args:
force_recreate: If True, recreate the collection if dimension mismatch occurs
Returns:
True if successful, False otherwise
"""
# Check if vectorstore_manager is initialized
if self.vectorstore_manager is None:
print("❌ Vector store manager is not initialized")
print("🔄 Attempting to initialize vector store manager...")
try:
self._initialize_components()
if self.vectorstore_manager is None:
print("❌ Failed to initialize vector store manager")
return False
except Exception as init_error:
print(f"❌ Error initializing vector store manager: {init_error}")
traceback.print_exc()
return False
try:
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
if vectorstore:
print("✅ Connected to vector store")
return True
else:
print("❌ Failed to connect to vector store")
return False
except Exception as e:
print(f"❌ Error connecting to vector store: {e}")
log_error(e, {"component": "vectorstore_connection"})
traceback.print_exc()
# If it's a dimension mismatch error, try with force_recreate
if "dimensions" in str(e).lower() and not force_recreate:
print("🔄 Dimension mismatch detected, attempting to recreate collection...")
try:
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=True)
if vectorstore:
print("✅ Connected to vector store (recreated)")
return True
except Exception as recreate_error:
print(f"❌ Failed to recreate vector store: {recreate_error}")
log_error(recreate_error, {"component": "vectorstore_recreation"})
return False
def create_vectorstore(self) -> bool:
    """
    Build a new vector store from ``self.chunks``.

    The chunks are converted with ``chunks_to_documents`` and handed to the
    vector store manager. Any failure (including missing chunks) is printed,
    logged and reported through the return value.

    Returns:
        True if successful, False otherwise
    """
    try:
        if not self.chunks:
            raise ValueError("No chunks available for vector store creation")
        self.vectorstore_manager.create_from_documents(
            chunks_to_documents(self.chunks)
        )
        print("✅ Vector store created successfully")
    except Exception as e:
        print(f"❌ Error creating vector store: {e}")
        log_error(e, {"component": "vectorstore_creation"})
        return False
    else:
        return True
def create_audit_prompt(self, query: str, context_docs: List[Document]) -> str:
"""Create a prompt for the LLM to generate an answer."""
try:
# Ensure query is not None
if not query or not isinstance(query, str) or query.strip() == "":
return "Error: No query provided"
# Ensure context_docs is not None and is a list
if context_docs is None:
context_docs = []
# Filter out None documents and ensure they have content
valid_docs = []
for doc in context_docs:
if doc is not None:
if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
valid_docs.append(doc)
elif isinstance(doc, str) and doc.strip():
valid_docs.append(doc)
# Create context string
if valid_docs:
context_parts = []
for i, doc in enumerate(valid_docs, 1):
if hasattr(doc, 'page_content') and doc.page_content:
context_parts.append(f"Doc {i}: {doc.page_content}")
elif isinstance(doc, str) and doc.strip():
context_parts.append(f"Doc {i}: {doc}")
context_string = "\n\n".join(context_parts)
else:
context_string = "No relevant context found."
# Create the prompt
prompt = f"""
{self.system_prompt}
Context:
{context_string}
Query: {query}
Answer:"""
return prompt
except Exception as e:
print(f"Error creating audit prompt: {e}")
return f"Error creating prompt: {e}"
def _generate_answer(self, prompt: str) -> str:
"""Generate answer using the LLM."""
try:
if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
return "Error: No prompt provided"
# Ensure LLM client is available
if not self.llm_client:
return "Error: LLM client not available"
# Generate response using the correct method
if hasattr(self.llm_client, 'generate'):
# Use the generate method (for adapters)
response = self.llm_client.generate([{"role": "user", "content": prompt}])
# Extract content from LLMResponse
if hasattr(response, 'content'):
answer = response.content
else:
answer = str(response)
elif hasattr(self.llm_client, 'invoke'):
# Use the invoke method (for direct LangChain models)
response = self.llm_client.invoke(prompt)
# Extract content safely
if hasattr(response, 'content') and response.content is not None:
answer = response.content
elif isinstance(response, str) and response.strip():
answer = response
else:
answer = str(response) if response is not None else "Error: LLM returned None response"
else:
return "Error: LLM client has no generate or invoke method"
# Ensure answer is not None and is a string
if answer is None or not isinstance(answer, str):
return "Error: LLM returned invalid response"
return answer.strip()
except Exception as e:
print(f"Error generating answer: {e}")
return f"Error generating answer: {e}"
def run(
    self,
    query: str,
    reports: List[str] = None,
    sources: List[str] = None,
    subtype: List[str] = None,
    llm_provider: str = None,
    use_reranking: bool = True,
    search_mode: str = None,
    search_alpha: float = None,
    auto_infer_filters: bool = True,
    filters: Dict[str, Any] = None,
) -> PipelineResult:
    """
    Run the complete RAG pipeline: resolve filters, retrieve context and
    ask the LLM for an answer.

    Args:
        query: User query
        reports: List of specific report filenames
        sources: Source category filter; when empty, ``filters['sources']``
            is used instead
        subtype: List of subtypes/filenames
        llm_provider: Recorded in the result metadata only; it does not
            select the LLM client in this method
        use_reranking: NOTE(review): this argument is overridden by the
            config value ``retrieval.use_reranking`` (default False) before
            retrieval -- confirm whether the argument should take precedence
        search_mode: Recorded in the result metadata only
        search_alpha: Recorded in the result metadata only
        auto_infer_filters: Whether to auto-infer filters from query; only
            applies when no explicit filter was given
        filters: Optional dict of explicit filters; the keys read are
            'year', 'district', 'filenames' and 'sources'

    Returns:
        PipelineResult object. Failures are reported through ``answer`` and
        ``metadata['error']``; this method does not raise.
    """
    try:
        # Validate input
        if not query or not isinstance(query, str) or query.strip() == "":
            return PipelineResult(
                answer="Error: Invalid query provided",
                sources=[],
                execution_time=0.0,
                metadata={'error': 'Invalid query'},
                query=query
            )

        # Ensure lists are not None
        if reports is None:
            reports = []
        if subtype is None:
            subtype = []

        start_time = time.time()

        inferred_filters = {}
        filters_applied = False
        qdrant_filter = None
        # Defined up front so the metadata below never needs a locals() probe.
        filter_summary = None

        # ALWAYS extract filter parameters from the filters dict first;
        # they are needed on every path below.
        explicit = filters if filters else {}
        year = explicit.get('year', [])
        district = explicit.get('district', [])
        filenames = explicit.get('filenames', [])
        sources_from_filters = explicit.get('sources', [])

        # Use sources from filters dict if not provided directly
        if sources_from_filters and not sources:
            sources = sources_from_filters

        # Check if any explicit filters were provided
        has_explicit_filters = any([reports, sources, subtype, year, district, filenames])

        if has_explicit_filters:
            print(f"✅ EXPLICIT FILTERS: year={year}, district={district}, sources={sources}, filenames={filenames}")
            filters_applied = True
        elif auto_infer_filters:
            print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
            try:
                # Get available metadata
                available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
                # Infer filters from query - this returns a Qdrant filter
                qdrant_filter, filter_summary = infer_filters_from_query(
                    query=query,
                    available_metadata=available_metadata,
                    llm_client=self.llm_client
                )
                if qdrant_filter:
                    print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
                    filters_applied = True
                else:
                    print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")
            except Exception as e:
                # Inference is best-effort: continue without a filter.
                print(f"❌ AUTO-INFERENCE FAILED: {e}")
                qdrant_filter = None
        else:
            print(f"⚠️ NO FILTERS: No explicit filters and auto-inference disabled")

        # Get vectorstore
        vectorstore = self.vectorstore_manager.get_vectorstore()
        if not vectorstore:
            return PipelineResult(
                answer="Error: Vector store not available",
                sources=[],
                execution_time=0.0,
                metadata={'error': 'Vector store not available'},
                query=query
            )

        # Initialize context retriever if not already done. The vector store
        # verified above is reused, so every exit path yields a PipelineResult.
        if getattr(self, 'context_retriever', None) is None:
            self.context_retriever = ContextRetriever(vectorstore, self.config)
            print("✅ ContextRetriever initialized successfully")

        # Debug config access
        print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
        print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
        print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
        print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")

        # Get the correct top_k value
        # Priority: experiment config > retriever config > default
        top_k = (
            self.config.get('retrieval', {}).get('top_k') or
            self.config.get('retriever', {}).get('top_k') or
            5
        )

        # Get reranking setting (the config value replaces the argument)
        use_reranking = self.config.get('retrieval', {}).get('use_reranking', False)

        print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
        print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")

        # Retrieve context using the context retriever
        context_docs = self.context_retriever.retrieve_context(
            query=query,
            k=top_k,
            reports=reports,
            sources=sources,
            subtype=subtype,
            year=year,
            district=district,
            filenames=filenames,
            use_reranking=use_reranking,
            qdrant_filter=qdrant_filter
        )

        # Ensure context_docs is not None
        if context_docs is None:
            context_docs = []

        # Generate answer
        answer = self._generate_answer(self.create_audit_prompt(query, context_docs))

        execution_time = time.time() - start_time

        reranker_model = self._get_reranker_model_name()

        # Per-document summary, produced only when reranking was applied.
        if use_reranking:
            reranked_documents = [
                {
                    'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': doc.metadata,
                    'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
                    'original_rank': doc.metadata.get('original_rank', None),
                    'final_rank': doc.metadata.get('final_rank', None),
                    'reranked_score': doc.metadata.get('reranked_score', None)
                }
                for doc in context_docs
            ]
        else:
            reranked_documents = None

        # Create result with comprehensive metadata
        return PipelineResult(
            answer=answer,
            sources=context_docs,
            execution_time=execution_time,
            metadata={
                'llm_provider': llm_provider,
                'use_reranking': use_reranking,
                'search_mode': search_mode,
                'search_alpha': search_alpha,
                'auto_infer_filters': auto_infer_filters,
                'filters_applied': filters_applied,
                'with_filtering': filters_applied,
                'filter_conditions': {
                    'reports': reports,
                    'sources': sources,
                    'subtype': subtype
                },
                'inferred_filters': inferred_filters,
                'applied_filters': {
                    'reports': reports,
                    'sources': sources,
                    'subtype': subtype
                },
                # Store filter and reranking metadata
                'filter_details': {
                    'explicit_filters': {
                        'reports': reports,
                        'sources': sources,
                        'subtype': subtype,
                        'year': year
                    },
                    'inferred_filters': inferred_filters if auto_infer_filters else {},
                    'auto_inference_enabled': auto_infer_filters,
                    'qdrant_filter_applied': qdrant_filter is not None,
                    'filter_summary': filter_summary
                },
                'reranker_model': reranker_model if use_reranking else None,
                'reranker_applied': use_reranking,
                'reranking_info': {
                    'model': reranker_model,
                    'applied': use_reranking,
                    'top_k': len(context_docs) if context_docs else 0,
                    'reranked_documents': reranked_documents
                }
            },
            query=query
        )
    except Exception as e:
        print(f"Error in pipeline run: {e}")
        return PipelineResult(
            answer=f"Error processing query: {e}",
            sources=[],
            execution_time=0.0,
            metadata={'error': str(e)},
            query=query
        )
def get_system_status(self) -> Dict[str, Any]:
"""
Get system status information.
Returns:
Dictionary with system status
"""
status = {
"config_loaded": bool(self.config),
"chunks_loaded": bool(self.chunks),
"vectorstore_connected": bool(
self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
),
"components_initialized": bool(
self.context_retriever and self.report_service
),
}
if self.chunks:
status["num_chunks"] = len(self.chunks)
if self.report_service:
status["available_sources"] = self.report_service.get_available_sources()
status["available_reports"] = len(
self.report_service.get_available_reports()
)
status["overall_status"] = (
"ready"
if all(
[
status["config_loaded"],
status["chunks_loaded"],
status["vectorstore_connected"],
status["components_initialized"],
]
)
else "not_ready"
)
return status
def get_available_llm_providers(self) -> List[str]:
"""Get list of available LLM providers."""
providers = []
reader_config = self.config.get("reader", {})
for provider in [
"MISTRAL",
"OPENAI",
"OLLAMA",
"INF_PROVIDERS",
"NVIDIA",
"DEDICATED",
"OPENROUTER",
]:
if provider in reader_config:
providers.append(provider.lower())
return providers