"""FDAM retriever with priority weighting and reranking.
Implements tiered retrieval:
1. Vector similarity search
2. Priority weighting (primary > reference-threshold > reference-narrative)
3. Optional reranking for production
"""
import logging
import time
from typing import Optional, TYPE_CHECKING
from dataclasses import dataclass
from config.settings import settings
# Type hints only - actual import deferred to __init__
if TYPE_CHECKING:
from .vectorstore import ChromaVectorStore
logger = logging.getLogger(__name__)
@dataclass
class RetrievalResult:
"""A single retrieval result with relevance score."""
chunk_id: str
text: str
source: str
category: str
section: str
priority: str
content_type: str
keywords: list[str]
similarity_score: float # 0-1, higher is better
weighted_score: float # After priority weighting
final_score: float # After reranking (if applied)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"chunk_id": self.chunk_id,
"text": self.text,
"source": self.source,
"category": self.category,
"section": self.section,
"priority": self.priority,
"content_type": self.content_type,
"keywords": self.keywords,
"similarity_score": self.similarity_score,
"weighted_score": self.weighted_score,
"final_score": self.final_score,
}
class MockReranker:
"""Mock reranker for local development.
Simply returns scores based on keyword overlap.
"""
def rerank(
self,
query: str,
documents: list[str],
) -> list[float]:
"""Score documents based on keyword overlap with query.
Args:
query: Query text
documents: List of document texts
Returns:
List of scores (0-1) for each document
"""
query_words = set(query.lower().split())
scores = []
for doc in documents:
doc_words = set(doc.lower().split())
            # Jaccard similarity: token-set intersection over union
overlap = len(query_words & doc_words)
total = len(query_words | doc_words)
score = overlap / total if total > 0 else 0.0
scores.append(score)
return scores
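
# Illustrative usage sketch of MockReranker (kept as a comment so nothing
# runs at import time):
#
#     scores = MockReranker().rerank(
#         "lead clearance threshold",
#         ["lead clearance threshold table", "general cleaning narrative"],
#     )
#     # The first document shares all three query tokens, so its Jaccard
#     # score (3/4 = 0.75) beats the second's (0/6 = 0.0).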
class SharedReranker:
"""Reranker that uses the shared model from RealModelStack.
    This avoids loading a duplicate reranker model by reusing the model
    already loaded by the pipeline at startup.
"""
def rerank(
self,
query: str,
documents: list[str],
) -> list[float]:
"""Score documents using the shared reranker model.
Args:
query: Query text
documents: List of document texts
Returns:
List of scores (0-1) for each document
"""
from models.loader import get_models
model_stack = get_models()
# Use the shared reranker model (always loaded at startup)
return model_stack.reranker.rerank(query, documents)
def get_reranker():
"""Get appropriate reranker based on settings.
For real models, uses SharedReranker which wraps the
model stack's reranker model (no duplicate loading).
"""
if settings.mock_models:
return MockReranker()
return SharedReranker()
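
# Selection sketch: with settings.mock_models enabled this returns the
# lightweight MockReranker; otherwise SharedReranker defers to the model
# stack loaded at startup. For example:
#
#     reranker = get_reranker()
#     scores = reranker.rerank("query text", ["doc one", "doc two"])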
class FDAMRetriever:
"""FDAM-specific retriever with priority weighting.
Priority weights:
- primary: 1.0 (FDAM methodology)
- reference-threshold: 0.9 (Threshold tables)
- reference-narrative: 0.8 (Supporting documentation)
"""
PRIORITY_WEIGHTS = {
"primary": 1.0,
"reference-threshold": 0.9,
"reference-narrative": 0.8,
}
def __init__(
self,
vectorstore: Optional["ChromaVectorStore"] = None,
reranker=None,
use_reranking: bool = True,
):
"""Initialize retriever.
Args:
vectorstore: ChromaDB vector store instance.
If None, creates default instance.
reranker: Reranker instance. If None, uses appropriate default.
use_reranking: Whether to apply reranking step.
"""
if vectorstore is None:
# Lazy import to avoid chromadb dependency at module load
from .vectorstore import ChromaVectorStore
vectorstore = ChromaVectorStore()
self.vectorstore = vectorstore
self.reranker = reranker if reranker is not None else get_reranker()
self.use_reranking = use_reranking
def retrieve(
self,
query: str,
top_k: int = 5,
category_filter: Optional[str] = None,
priority_filter: Optional[str] = None,
include_scores: bool = True,
) -> list[RetrievalResult]:
"""Retrieve relevant chunks for a query.
Args:
query: Query text
top_k: Number of results to return
category_filter: Optional category to filter by
priority_filter: Optional priority to filter by
include_scores: Whether to include score details
Returns:
List of RetrievalResult objects, sorted by final_score descending
"""
start_time = time.time()
logger.debug(f"RAG retrieve: query='{query[:50]}...' top_k={top_k}")
# Build metadata filter
where_filter = None
if category_filter or priority_filter:
where_filter = {}
if category_filter:
where_filter["category"] = category_filter
if priority_filter:
where_filter["priority"] = priority_filter
# Fetch more results than needed for reranking
fetch_k = top_k * 3 if self.use_reranking else top_k
# Query vector store
raw_results = self.vectorstore.query(
query_text=query,
n_results=fetch_k,
where=where_filter,
)
if not raw_results:
logger.debug("RAG retrieve: no results found")
return []
# Convert to RetrievalResult objects with priority weighting
results = []
for r in raw_results:
            # Convert distance to similarity; ChromaDB cosine distance is in
            # [0, 2] with 0 = identical, so clamp to keep the score in [0, 1]
            similarity = max(0.0, 1.0 - r["distance"])
# Apply priority weight
priority = r["metadata"].get("priority", "reference-narrative")
weight = self.PRIORITY_WEIGHTS.get(priority, 0.8)
weighted_score = similarity * weight
            # Parse comma-separated keywords, trimming whitespace and
            # dropping empty entries
            keywords_str = r["metadata"].get("keywords", "")
            keywords = [k.strip() for k in keywords_str.split(",") if k.strip()]
results.append(
RetrievalResult(
chunk_id=r["id"],
text=r["document"],
source=r["metadata"].get("source", "unknown"),
category=r["metadata"].get("category", "unknown"),
section=r["metadata"].get("section", "unknown"),
priority=priority,
content_type=r["metadata"].get("content_type", "narrative"),
keywords=keywords,
similarity_score=similarity,
weighted_score=weighted_score,
final_score=weighted_score, # Will be updated by reranking
)
)
# Apply reranking if enabled
if self.use_reranking and results:
logger.debug(f"Applying reranking to {len(results)} results")
documents = [r.text for r in results]
rerank_scores = self.reranker.rerank(query, documents)
# Combine weighted score with rerank score
# Final = 0.6 * weighted + 0.4 * rerank
for i, result in enumerate(results):
rerank_score = rerank_scores[i]
result.final_score = 0.6 * result.weighted_score + 0.4 * rerank_score
# Sort by final score (descending) and take top_k
results.sort(key=lambda x: x.final_score, reverse=True)
final_results = results[:top_k]
# Log retrieval summary
elapsed = time.time() - start_time
if final_results:
top_score = final_results[0].final_score
top_source = final_results[0].source
logger.debug(f"RAG retrieve: {len(final_results)} results in {elapsed:.3f}s, "
f"top_score={top_score:.3f}, top_source={top_source}")
else:
logger.debug(f"RAG retrieve: 0 results in {elapsed:.3f}s")
return final_results
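    # Worked example of the scoring pipeline above (illustrative numbers):
    # a chunk at cosine distance 0.25 with priority "reference-threshold"
    # and a rerank score of 0.50 yields
    #   similarity = 1.0 - 0.25               = 0.75
    #   weighted   = 0.75 * 0.9               = 0.675
    #   final      = 0.6 * 0.675 + 0.4 * 0.50 = 0.605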
def retrieve_for_context(
self,
query: str,
top_k: int = 5,
) -> str:
"""Retrieve and format chunks as context string for LLM.
Args:
query: Query text
top_k: Number of chunks to include
Returns:
Formatted context string with source citations
"""
results = self.retrieve(query, top_k=top_k)
if not results:
return "No relevant context found."
context_parts = []
for i, r in enumerate(results, 1):
context_parts.append(
f"[{i}] Source: {r.source} | Section: {r.section}\n{r.text}"
)
return "\n\n---\n\n".join(context_parts)
def retrieve_thresholds(
self,
material_type: str,
facility_type: str,
) -> list[RetrievalResult]:
"""Retrieve threshold values for a specific material and facility type.
Convenience method for threshold lookups.
Args:
material_type: Type of material (e.g., "lead", "soot", "char")
facility_type: Facility classification
Returns:
Relevant threshold results
"""
query = f"{material_type} threshold {facility_type} clearance criteria"
return self.retrieve(
query=query,
top_k=3,
category_filter="thresholds",
)
def retrieve_disposition(
self,
zone: str,
condition: str,
material_type: Optional[str] = None,
) -> list[RetrievalResult]:
"""Retrieve disposition guidance for zone/condition combination.
Convenience method for disposition lookups.
Args:
zone: Zone classification (burn-zone, near-field, far-field)
condition: Condition level (background, light, moderate, heavy, structural-damage)
material_type: Optional material type for specific guidance
Returns:
Relevant disposition results
"""
query = f"disposition {zone} {condition}"
if material_type:
query += f" {material_type}"
query += " cleaning recommendation"
return self.retrieve(
query=query,
top_k=5,
priority_filter="primary", # Prefer FDAM methodology
)
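    # For example, retrieve_disposition("burn-zone", "heavy", "soot") issues
    # the query "disposition burn-zone heavy soot cleaning recommendation",
    # restricted to primary-priority FDAM chunks.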
def retrieve_cleaning_method(
self,
surface_type: str,
condition: str,
) -> list[RetrievalResult]:
"""Retrieve cleaning method recommendations.
Args:
surface_type: Type of surface (e.g., "drywall", "concrete", "metal")
condition: Condition level
Returns:
Relevant cleaning method results
"""
query = f"cleaning method {surface_type} {condition} procedure hepa"
return self.retrieve(
query=query,
top_k=5,
)