"""Document retriever for RAG pipeline."""

from pathlib import Path
from typing import Optional

from pydantic import BaseModel

from src.config import settings
from src.knowledge.vector_store import FAISSVectorStore, RetrievalResult
from src.rag.prompts import CONTEXT_CHUNK_TEMPLATE
class RetrievalContext(BaseModel):
    """Context retrieved for report generation.

    Bundles the categorized retrieval results together with pre-formatted
    context strings ready to be interpolated into the report prompts.
    """

    # Chunks classified as policy/handbook material.
    policy_results: list[RetrievalResult]
    # Chunks classified as prior warning/coaching records.
    warning_results: list[RetrievalResult]
    # policy_results rendered as a single LLM-ready context string.
    policy_context_text: str
    # warning_results rendered as a single LLM-ready context string.
    warnings_context_text: str
    # True when at least one policy or warning chunk was retrieved.
    has_sufficient_evidence: bool
    # De-duplicated source file names backing the retrieved chunks.
    sources_used: list[str]

    class Config:
        # Presumably needed because RetrievalResult is not a plain
        # pydantic-compatible type — TODO confirm against vector_store.
        arbitrary_types_allowed = True
class DocumentRetriever:
    """Retrieves relevant documents for HR report generation.

    Separates retrieval into policy documents and warning/coaching
    documents to ensure proper context for each report section.
    """

    def __init__(self, vector_store: Optional[FAISSVectorStore] = None):
        """Initialize the retriever.

        Args:
            vector_store: Vector store to search. Creates new one if not provided.
        """
        self.vector_store = vector_store or FAISSVectorStore()
        # Lazily load a previously persisted index on first use.
        # NOTE(review): reaches into the private `_is_loaded` flag — consider
        # exposing a public `is_loaded` property on FAISSVectorStore.
        if not self.vector_store._is_loaded:
            self.vector_store.load()

    def _format_results_as_context(self, results: list[RetrievalResult]) -> str:
        """Format retrieval results as a context string for the LLM.

        Args:
            results: Scored chunks to render.

        Returns:
            Newline-joined rendering of each chunk via CONTEXT_CHUNK_TEMPLATE,
            or a fixed placeholder string when ``results`` is empty.
        """
        if not results:
            return "No relevant documents found."
        context_parts = [
            CONTEXT_CHUNK_TEMPLATE.format(
                source_file=Path(result.chunk.source_file).name,
                section_title=result.chunk.section_title or "General",
                score=result.score,
                content=result.chunk.content,
            )
            for result in results
        ]
        return "\n".join(context_parts)

    def _classify_results(
        self, results: list[RetrievalResult]
    ) -> tuple[list[RetrievalResult], list[RetrievalResult]]:
        """Classify results into policy and warning categories.

        Uses simple heuristics: a warning keyword appearing in the source
        file stem or in the first 200 characters of the chunk content puts
        the chunk in the warning bucket; everything else goes to the policy
        bucket (which is also the default for unmatched chunks). Input
        order is preserved within each bucket.

        Returns:
            Tuple of (policy_results, warning_results).
        """
        policy_results: list[RetrievalResult] = []
        warning_results: list[RetrievalResult] = []
        warning_keywords = ("warning", "coaching", "counseling", "disciplinary", "incident")
        for result in results:
            source_lower = Path(result.chunk.source_file).stem.lower()
            content_lower = result.chunk.content.lower()
            is_warning = any(
                kw in source_lower or kw in content_lower[:200]
                for kw in warning_keywords
            )
            if is_warning:
                warning_results.append(result)
            else:
                # A separate policy-keyword check would be redundant: any
                # non-warning chunk defaults to the policy bucket anyway.
                policy_results.append(result)
        return policy_results, warning_results

    def retrieve(
        self,
        employee_name: str,
        violation_type: str,
        incident_reason: str,
        top_k: Optional[int] = None,
        min_score: Optional[float] = None,
    ) -> RetrievalContext:
        """Retrieve relevant context for report generation.

        Args:
            employee_name: Name of the employee.
            violation_type: Type of violation (e.g., "Tardiness").
            incident_reason: Description of the incident.
            top_k: Number of results per query. Defaults to
                ``settings.retrieval_top_k``.
            min_score: Minimum similarity score. Defaults to
                ``settings.retrieval_min_score``.

        Returns:
            RetrievalContext with categorized results.
        """
        # Explicit None checks so falsy-but-valid caller values
        # (e.g. min_score=0.0) are not silently replaced by the defaults.
        if top_k is None:
            top_k = settings.retrieval_top_k
        if min_score is None:
            min_score = settings.retrieval_min_score

        # Build search queries targeting policies, prior warnings, and the
        # specific incident respectively.
        policy_query = f"{violation_type} policy procedure disciplinary action"
        warning_query = f"{employee_name} warning coaching disciplinary {violation_type}"
        incident_query = f"{incident_reason} {violation_type}"

        # Execute searches
        policy_results = self.vector_store.search(policy_query, top_k=top_k, min_score=min_score)
        warning_results = self.vector_store.search(warning_query, top_k=top_k, min_score=min_score)
        incident_results = self.vector_store.search(
            incident_query, top_k=top_k, min_score=min_score
        )

        # Combine and deduplicate by chunk id, keeping the best score seen.
        all_results: dict = {}
        for result in policy_results + warning_results + incident_results:
            chunk_id = result.chunk.chunk_id
            if chunk_id not in all_results or result.score > all_results[chunk_id].score:
                all_results[chunk_id] = result

        # Re-classify all results, best score first.
        all_results_list = sorted(all_results.values(), key=lambda r: r.score, reverse=True)
        policy_classified, warning_classified = self._classify_results(all_results_list)

        # Format as context text
        policy_context = self._format_results_as_context(policy_classified)
        warnings_context = self._format_results_as_context(warning_classified)

        # Sufficient evidence means at least one chunk in either bucket.
        has_evidence = len(policy_classified) > 0 or len(warning_classified) > 0

        # Collect unique source filenames; sorted for deterministic output
        # (a raw set would yield an arbitrary order).
        sources = sorted(
            {
                Path(r.chunk.source_file).name
                for r in policy_classified + warning_classified
            }
        )

        return RetrievalContext(
            policy_results=policy_classified,
            warning_results=warning_classified,
            policy_context_text=policy_context,
            warnings_context_text=warnings_context,
            has_sufficient_evidence=has_evidence,
            sources_used=sources,
        )

    def retrieve_for_employee(
        self, employee_name: str, top_k: int = 10
    ) -> list[RetrievalResult]:
        """Retrieve all documents mentioning an employee.

        Useful for finding prior warnings and coaching records. Uses a
        deliberately low score threshold (0.2) to favor recall.
        """
        return self.vector_store.search(employee_name, top_k=top_k, min_score=0.2)