# NOTE(review): scrape residue from the HF Spaces commit UI converted to a
# comment so the module parses — sadickam, "Initial commit for HF Space", 3326079
"""Pydantic models and protocols for retrieval operations.
This module defines the interfaces and data models used throughout the retrieval
pipeline. These models provide type-safe, validated representations of retrieval
queries, results, and the retriever interface protocol.
Models:
- RetrievalResult: Represents a single retrieved chunk with relevance score
- RetrievalQuery: Represents a search query with configuration options
Protocols:
- Retriever: Protocol defining the retriever interface for dependency injection
Functions:
- normalize_text: Utility function for cleaning and normalizing text
Design Principles:
- Protocol-based interface enables dependency injection and testing
- Pydantic v2 for validation and serialization
- Frozen models for immutability (hashable only when all field values are hashable)
- Automatic text normalization on model creation
Lazy Loading:
No heavy dependencies in this module. Pydantic is a lightweight dependency
that loads quickly.
Example:
-------
>>> from rag_chatbot.retrieval.models import RetrievalQuery, RetrievalResult
>>> query = RetrievalQuery(query="What is PMV?", top_k=5)
>>> result = RetrievalResult(
... chunk_id="ashrae55_001",
... text="The PMV model predicts thermal sensation...",
... score=0.85,
... heading_path=["Chapter 1", "Section 1.1"],
... source="ashrae_55.pdf",
... page=5,
... )
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
)
# =============================================================================
# Type Checking Imports
# =============================================================================
# These imports are only processed by type checkers (mypy, pyright) and IDEs.
# They enable proper type hints without runtime overhead.
# =============================================================================
if TYPE_CHECKING:
    pass  # No type-only imports needed currently

# =============================================================================
# Module Exports
# =============================================================================
# Public API of this module. Star-imports and documentation tools honor this
# list; everything else (e.g. the private regex constants) is internal.
__all__: list[str] = [
    "RetrievalResult",
    "RetrievalQuery",
    "Retriever",
    "normalize_text",
]
# =============================================================================
# Constants
# =============================================================================
# Matches runs of spaces/tabs (NOT newlines). Used to collapse inline
# whitespace while leaving paragraph breaks untouched.
_MULTI_SPACE_PATTERN: re.Pattern[str] = re.compile(r"[ \t]+")

# Matches sentence-ending punctuation, the whitespace after it, and a
# lowercase letter. The whitespace is CAPTURED (group 2) so that the
# substitution can re-emit it verbatim — including newlines — when the
# letter is capitalized. (Previously the whitespace was consumed and
# replaced with a single space, which destroyed paragraph breaks.)
_SENTENCE_END_PATTERN: re.Pattern[str] = re.compile(r"([.!?])(\s+)([a-z])")


# =============================================================================
# Text Normalization
# =============================================================================
def normalize_text(text: str) -> str:
    """Normalize text by cleaning whitespace and fixing sentence structure.

    This function performs the following normalizations:

    1. Strips leading and trailing whitespace
    2. Collapses multiple spaces/tabs into a single space
    3. Preserves newlines (single newlines are kept as-is, including
       between sentences)
    4. Capitalizes the first letter after sentence-ending punctuation

    The function is designed to handle common text issues from PDF
    extraction while preserving intentional formatting such as paragraph
    breaks.

    Args:
    ----
        text: The text string to normalize.

    Returns:
    -------
        Cleaned text with normalized whitespace and proper sentence
        structure.

    Example:
    -------
        >>> normalize_text("  Hello   world  ")
        'Hello world'
        >>> normalize_text("First sentence. second sentence")
        'First sentence. Second sentence'
        >>> normalize_text("End of paragraph.\\nnext paragraph")
        'End of paragraph.\\nNext paragraph'

    Note:
    ----
        - Single spaces within words like "wo rd" are preserved as-is.
          The text normalizer in chunking.models handles domain-specific
          word corrections.
        - Newlines are preserved to maintain paragraph structure, even
          when the next paragraph starts with a lowercase letter (the
          letter is still capitalized).
        - Empty strings return empty strings (no error raised).
    """
    # Handle empty or whitespace-only input early.
    if not text or not text.strip():
        return ""

    # Step 1: Strip leading and trailing whitespace.
    result = text.strip()

    # Step 2: Collapse runs of spaces/tabs into a single space. The pattern
    # deliberately excludes newlines so paragraph structure survives.
    result = _MULTI_SPACE_PATTERN.sub(" ", result)

    # Step 3: Capitalize the first letter after sentence-ending punctuation,
    # e.g. "First sentence. second" -> "First sentence. Second". The matched
    # whitespace (group 2) is re-emitted unchanged so newlines are preserved.
    def _capitalize_after_sentence(match: re.Match[str]) -> str:
        """Reassemble punctuation + original whitespace + uppercased letter."""
        return f"{match.group(1)}{match.group(2)}{match.group(3).upper()}"

    return _SENTENCE_END_PATTERN.sub(_capitalize_after_sentence, result)
# =============================================================================
# Protocols
# =============================================================================
@runtime_checkable
class Retriever(Protocol):
    """Structural interface for document retrievers.

    Any object that exposes a matching ``retrieve`` method satisfies this
    protocol, regardless of the underlying backend (FAISS, BM25, hybrid,
    in-memory mocks, ...). This enables dependency injection and makes
    tests trivial: a stub class needs no inheritance, only the method.

    Because of ``@runtime_checkable``, ``isinstance(obj, Retriever)`` works
    at runtime. Note that runtime protocol checks only verify that the
    method *exists*, not that its signature matches.

    Example:
    -------
    >>> class MockRetriever:
    ...     def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
    ...         return []
    >>> retriever: Retriever = MockRetriever()
    >>> isinstance(retriever, Retriever)
    True
    """

    def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
        """Return the chunks most relevant to ``query``.

        Implementations search the document store and return results
        ranked by relevance, best first.

        Args:
        ----
            query: Natural-language question or keyword phrase to search
                for.
            top_k: Maximum number of results to return. Defaults to 6;
                must be positive.

        Returns:
        -------
            ``RetrievalResult`` objects sorted by score in descending
            order. The list may hold fewer than ``top_k`` items when the
            store contains too few relevant chunks.

        Raises:
        ------
            ValueError: If ``query`` is empty or ``top_k`` is not positive.
            RuntimeError: If retrieval fails due to index issues.
        """
        ...
# =============================================================================
# Data Models
# =============================================================================
class RetrievalResult(BaseModel):
    """A single retrieved chunk together with its relevance score.

    Bundles everything needed to present a search hit to a user or to
    pass context to the LLM: the (normalized) chunk text, citation
    metadata, and a relevance score used for ranking and filtering.

    Instances are frozen (immutable), which keeps them safe to share
    across threads.

    Attributes:
    ----------
    chunk_id : str
        Unique identifier of the chunk within the corpus; used for
        deduplication and logging.
    text : str
        Chunk content, automatically run through ``normalize_text()``.
        This is what gets shown to the user or fed to the LLM.
    score : float
        Relevance in ``[0.0, 1.0]``; 1.0 means a perfect match.
    heading_path : list[str]
        Hierarchical heading trail giving the chunk context, e.g.
        ``["Chapter 1", "Section 1.1", "PMV Model"]``. Empty when no
        heading hierarchy is available.
    source : str
        Source document name or path for citations, e.g. "ashrae_55.pdf".
    page : int
        1-indexed page number in the source document.

    Example:
    -------
    >>> result = RetrievalResult(
    ...     chunk_id="ashrae55_001",
    ...     text="The PMV model predicts thermal sensation...",
    ...     score=0.85,
    ...     heading_path=["Thermal Comfort", "PMV Model"],
    ...     source="ashrae_55.pdf",
    ...     page=5,
    ... )
    >>> result.score
    0.85
    >>> result.heading_path
    ['Thermal Comfort', 'PMV Model']
    """

    # Reject unknown fields, freeze instances, accept population by field
    # name or alias, validate defaults, and publish schema examples.
    model_config = ConfigDict(
        extra="forbid",
        frozen=True,
        populate_by_name=True,
        validate_default=True,
        json_schema_extra={
            "examples": [
                {
                    "chunk_id": "ashrae55_001",
                    "text": "The PMV model predicts thermal sensation based on...",
                    "score": 0.85,
                    "heading_path": ["Thermal Comfort", "PMV Model"],
                    "source": "ashrae_55.pdf",
                    "page": 5,
                },
                {
                    "chunk_id": "iso7730_042",
                    "text": "The PPD index represents the percentage of...",
                    "score": 0.72,
                    "heading_path": ["ISO 7730", "PPD Calculation"],
                    "source": "iso_7730.pdf",
                    "page": 12,
                },
            ]
        },
    )

    # ---- Fields -------------------------------------------------------------
    chunk_id: str = Field(
        ...,
        min_length=1,
        description="Unique identifier for the chunk within the corpus",
        examples=["ashrae55_001", "iso7730_042", "guide_chapter2_015"],
    )
    text: str = Field(
        ...,
        min_length=1,
        description="The text content of the chunk (normalized)",
        examples=["The PMV model predicts thermal sensation based on..."],
    )
    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Relevance score between 0.0 and 1.0",
        examples=[0.85, 0.72, 0.65],
    )
    heading_path: list[str] = Field(
        default_factory=list,
        description="Hierarchical path of headings providing context",
        examples=[["Chapter 1", "Section 1.1"], ["Thermal Comfort", "PMV Model"]],
    )
    source: str = Field(
        ...,
        min_length=1,
        description="Source document name or path",
        examples=["ashrae_55.pdf", "iso_7730.pdf", "pythermalcomfort_guide.pdf"],
    )
    page: int = Field(
        ...,
        ge=1,
        description="Page number in the source document (1-indexed)",
        examples=[1, 5, 42],
    )

    # ---- Validators ---------------------------------------------------------
    @field_validator("text", mode="before")
    @classmethod
    def _normalize_text_field(cls, value: object) -> str:
        """Run the raw text through ``normalize_text`` before validation.

        Raises:
        ------
            ValueError: If the value is None or normalizes to "".
        """
        if value is None:
            raise ValueError("text cannot be None")
        normalized = normalize_text(str(value))
        if not normalized:
            raise ValueError("text cannot be empty after normalization")
        return normalized

    @field_validator("chunk_id", "source", mode="before")
    @classmethod
    def _strip_string_fields(cls, value: object) -> str:
        """Coerce identifier fields to stripped strings; reject blanks.

        Raises:
        ------
            ValueError: If the value is None or strips down to "".
        """
        if value is None:
            raise ValueError("Field cannot be None")
        stripped = str(value).strip()
        if not stripped:
            raise ValueError("Field cannot be empty")
        return stripped

    @field_validator("heading_path", mode="before")
    @classmethod
    def _ensure_heading_list(cls, value: object) -> list[str]:
        """Coerce any input into a list of non-empty, stripped strings.

        None -> []; a bare string becomes a one-element list; any other
        iterable is stringified element-wise; a non-iterable scalar is
        wrapped in a single-element list (empty strings are dropped).
        """
        if value is None:
            return []
        if isinstance(value, str):
            # A bare string means a single heading was provided.
            heading = value.strip()
            return [heading] if heading else []
        try:
            items = iter(value)  # type: ignore[call-overload]
        except TypeError:
            # Not iterable: wrap the scalar unless it strips to nothing.
            heading = str(value).strip()
            return [heading] if heading else []
        # Lists and other iterables: stringify, strip, and drop empties.
        return [h for h in (str(item).strip() for item in items) if h]
class RetrievalQuery(BaseModel):
    """Search parameters for a retrieval operation.

    Wraps the query string together with the options that control how
    the search is performed (result count, optional reranking).

    Instances are frozen (immutable), which keeps them safe to share
    across threads and usable as cache keys.

    Attributes:
    ----------
    query : str
        The search string, automatically run through ``normalize_text()``.
        Should be a natural-language question or keyword phrase.
    top_k : int
        Number of results to return. Defaults to 6; must be positive.
    use_reranker : bool
        When True, results are reranked with a cross-encoder model for
        improved relevance. Defaults to False.

    Example:
    -------
    >>> query = RetrievalQuery(query="What is PMV?")
    >>> query.query
    'What is PMV?'
    >>> query.top_k
    6
    >>> query = RetrievalQuery(query="thermal comfort", top_k=10, use_reranker=True)
    >>> query.use_reranker
    True
    """

    # Reject unknown fields, freeze instances, accept population by field
    # name or alias, validate defaults, and publish schema examples.
    model_config = ConfigDict(
        extra="forbid",
        frozen=True,
        populate_by_name=True,
        validate_default=True,
        json_schema_extra={
            "examples": [
                {
                    "query": "What is PMV?",
                    "top_k": 6,
                    "use_reranker": False,
                },
                {
                    "query": "How to calculate thermal comfort?",
                    "top_k": 10,
                    "use_reranker": True,
                },
            ]
        },
    )

    # ---- Fields -------------------------------------------------------------
    query: str = Field(
        ...,
        min_length=1,
        description="The search query string (normalized)",
        examples=["What is PMV?", "thermal comfort calculation", "ASHRAE 55 standard"],
    )
    top_k: int = Field(
        default=6,
        gt=0,
        description="Number of results to return",
        examples=[5, 6, 10, 20],
    )
    use_reranker: bool = Field(
        default=False,
        description="Whether to use a reranker for better result ranking",
    )

    # ---- Validators ---------------------------------------------------------
    @field_validator("query", mode="before")
    @classmethod
    def _normalize_query_field(cls, value: object) -> str:
        """Run the raw query through ``normalize_text`` before validation.

        Raises:
        ------
            ValueError: If the value is None or normalizes to "".
        """
        if value is None:
            raise ValueError("query cannot be None")
        normalized = normalize_text(str(value))
        if not normalized:
            raise ValueError("query cannot be empty after normalization")
        return normalized