# hrbot/src/rag/retriever.py
# Last updated by Sonu Prasad (commit 8a1c0d1).
"""Document retriever for RAG pipeline."""
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
from src.config import settings
from src.knowledge.vector_store import FAISSVectorStore, RetrievalResult
from src.rag.prompts import CONTEXT_CHUNK_TEMPLATE
class RetrievalContext(BaseModel):
    """Context retrieved for report generation.

    Bundles the classified retrieval results together with their
    pre-formatted context strings, ready for prompt construction.
    Produced by DocumentRetriever.retrieve().
    """
    # Results classified as policy/handbook documents.
    policy_results: list[RetrievalResult]
    # Results classified as warning/coaching/disciplinary documents.
    warning_results: list[RetrievalResult]
    # policy_results rendered through CONTEXT_CHUNK_TEMPLATE (LLM-ready text).
    policy_context_text: str
    # warning_results rendered through CONTEXT_CHUNK_TEMPLATE (LLM-ready text).
    warnings_context_text: str
    # True when at least one policy or warning result was retrieved.
    has_sufficient_evidence: bool
    # Unique source filenames (basename only) backing the results above.
    sources_used: list[str]
    class Config:
        # RetrievalResult is not a pydantic model, so pydantic must be told
        # to accept it as a field type without validation.
        arbitrary_types_allowed = True
class DocumentRetriever:
    """Retrieves relevant documents for HR report generation.

    Separates retrieval into policy documents and warning/coaching
    documents to ensure proper context for each report section.
    """

    # Heuristic keyword lists used by _classify_results. Matched against the
    # source filename stem and the first 200 characters of chunk content.
    _POLICY_KEYWORDS = ("policy", "handbook", "manual", "guideline", "procedure")
    _WARNING_KEYWORDS = ("warning", "coaching", "counseling", "disciplinary", "incident")

    def __init__(self, vector_store: Optional[FAISSVectorStore] = None):
        """Initialize the retriever.

        Args:
            vector_store: Vector store to search. Creates a new one if not
                provided.
        """
        self.vector_store = vector_store or FAISSVectorStore()
        # Load a previously persisted index if the store has not loaded one.
        # NOTE(review): this reads the store's private `_is_loaded` flag;
        # consider exposing a public `is_loaded` property on FAISSVectorStore.
        if not self.vector_store._is_loaded:
            self.vector_store.load()

    def _format_results_as_context(self, results: list[RetrievalResult]) -> str:
        """Format retrieval results as a context string for the LLM.

        Returns a fixed placeholder string when there are no results so the
        prompt never contains an empty section.
        """
        if not results:
            return "No relevant documents found."
        return "\n".join(
            CONTEXT_CHUNK_TEMPLATE.format(
                source_file=Path(result.chunk.source_file).name,
                section_title=result.chunk.section_title or "General",
                score=result.score,
                content=result.chunk.content,
            )
            for result in results
        )

    def _classify_results(
        self, results: list[RetrievalResult]
    ) -> tuple[list[RetrievalResult], list[RetrievalResult]]:
        """Classify results into (policy, warning) categories.

        Uses simple heuristics based on source filename and content. Warning
        keywords take precedence; anything that does not match a warning
        keyword defaults to the policy bucket (the original `is_policy`
        check was dead logic — its branch and the fallback branch were
        identical, so only the warning test ever mattered).
        """
        policy_results: list[RetrievalResult] = []
        warning_results: list[RetrievalResult] = []
        for result in results:
            source_lower = Path(result.chunk.source_file).stem.lower()
            # Only the leading content is scanned to keep the heuristic cheap
            # and biased toward document headers.
            content_head = result.chunk.content.lower()[:200]
            is_warning = any(
                kw in source_lower or kw in content_head
                for kw in self._WARNING_KEYWORDS
            )
            if is_warning:
                warning_results.append(result)
            else:
                # Policy documents and anything unclassified both land here;
                # policy context is the safer default for report generation.
                policy_results.append(result)
        return policy_results, warning_results

    def retrieve(
        self,
        employee_name: str,
        violation_type: str,
        incident_reason: str,
        top_k: Optional[int] = None,
        min_score: Optional[float] = None,
    ) -> RetrievalContext:
        """Retrieve relevant context for report generation.

        Args:
            employee_name: Name of the employee.
            violation_type: Type of violation (e.g., "Tardiness").
            incident_reason: Description of the incident.
            top_k: Number of results per query. Defaults to
                settings.retrieval_top_k.
            min_score: Minimum similarity score. Defaults to
                settings.retrieval_min_score.

        Returns:
            RetrievalContext with categorized results.
        """
        # Explicit None checks: `top_k or default` would silently replace
        # legitimate falsy values (top_k=0, min_score=0.0) with the defaults.
        if top_k is None:
            top_k = settings.retrieval_top_k
        if min_score is None:
            min_score = settings.retrieval_min_score

        # Three complementary queries: general policy for the violation,
        # employee-specific warning history, and the concrete incident.
        queries = [
            f"{violation_type} policy procedure disciplinary action",
            f"{employee_name} warning coaching disciplinary {violation_type}",
            f"{incident_reason} {violation_type}",
        ]

        # Deduplicate across queries by chunk_id, keeping the best score seen.
        best_by_chunk = {}
        for query in queries:
            for result in self.vector_store.search(
                query, top_k=top_k, min_score=min_score
            ):
                chunk_id = result.chunk.chunk_id
                existing = best_by_chunk.get(chunk_id)
                if existing is None or result.score > existing.score:
                    best_by_chunk[chunk_id] = result

        # Classify the deduplicated results in descending-score order.
        ranked = sorted(best_by_chunk.values(), key=lambda r: r.score, reverse=True)
        policy_classified, warning_classified = self._classify_results(ranked)

        # Sorted for deterministic output (set iteration order is unstable).
        sources = sorted(
            {
                Path(r.chunk.source_file).name
                for r in policy_classified + warning_classified
            }
        )

        return RetrievalContext(
            policy_results=policy_classified,
            warning_results=warning_classified,
            policy_context_text=self._format_results_as_context(policy_classified),
            warnings_context_text=self._format_results_as_context(warning_classified),
            has_sufficient_evidence=bool(policy_classified or warning_classified),
            sources_used=sources,
        )

    def retrieve_for_employee(
        self, employee_name: str, top_k: int = 10
    ) -> list[RetrievalResult]:
        """Retrieve all documents mentioning an employee.

        Useful for finding prior warnings and coaching records. Uses a low
        fixed min_score (0.2) to favor recall over precision.
        """
        return self.vector_store.search(employee_name, top_k=top_k, min_score=0.2)