Spaces:
Running
Running
| """Pydantic models and protocols for retrieval operations. | |
| This module defines the interfaces and data models used throughout the retrieval | |
| pipeline. These models provide type-safe, validated representations of retrieval | |
| queries, results, and the retriever interface protocol. | |
| Models: | |
| - RetrievalResult: Represents a single retrieved chunk with relevance score | |
| - RetrievalQuery: Represents a search query with configuration options | |
| Protocols: | |
| - Retriever: Protocol defining the retriever interface for dependency injection | |
| Functions: | |
| - normalize_text: Utility function for cleaning and normalizing text | |
| Design Principles: | |
| - Protocol-based interface enables dependency injection and testing | |
| - Pydantic v2 for validation and serialization | |
| - Frozen models for immutability and hashability | |
| - Automatic text normalization on model creation | |
| Lazy Loading: | |
| No heavy dependencies in this module. Pydantic is a lightweight dependency | |
| that loads quickly. | |
| Example: | |
| ------- | |
| >>> from rag_chatbot.retrieval.models import RetrievalQuery, RetrievalResult | |
| >>> query = RetrievalQuery(query="What is PMV?", top_k=5) | |
| >>> result = RetrievalResult( | |
| ... chunk_id="ashrae55_001", | |
| ... text="The PMV model predicts thermal sensation...", | |
| ... score=0.85, | |
| ... heading_path=["Chapter 1", "Section 1.1"], | |
| ... source="ashrae_55.pdf", | |
| ... page=5, | |
| ... ) | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import TYPE_CHECKING, Protocol, runtime_checkable | |
| from pydantic import ( | |
| BaseModel, | |
| ConfigDict, | |
| Field, | |
| field_validator, | |
| ) | |
| # ============================================================================= | |
| # Type Checking Imports | |
| # ============================================================================= | |
| # These imports are only processed by type checkers (mypy, pyright) and IDEs. | |
| # They enable proper type hints without runtime overhead. | |
| # ============================================================================= | |
| if TYPE_CHECKING: | |
| pass # No type-only imports needed currently | |
| # ============================================================================= | |
| # Module Exports | |
| # ============================================================================= | |
| __all__: list[str] = [ | |
| "RetrievalResult", | |
| "RetrievalQuery", | |
| "Retriever", | |
| "normalize_text", | |
| ] | |
| # ============================================================================= | |
| # Constants | |
| # ============================================================================= | |
# Runs of horizontal whitespace (spaces/tabs). Newlines are deliberately left
# out of the character class so collapsing inline spacing never merges lines.
_MULTI_SPACE_PATTERN: re.Pattern[str] = re.compile(r"[ \t]+")

# Sentence-ending punctuation, the whitespace after it, and the first ASCII
# lowercase letter of the next sentence.
# Group 1: the punctuation mark; group 2: the lowercase letter.
# NOTE: ``\s+`` also matches line breaks, so the replacement callback must put
# the matched whitespace back instead of assuming it was a single space.
_SENTENCE_END_PATTERN: re.Pattern[str] = re.compile(r"([.!?])\s+([a-z])")


# =============================================================================
# Text Normalization
# =============================================================================


def normalize_text(text: str) -> str:
    """Normalize text by cleaning whitespace and fixing sentence capitalization.

    The following normalizations are applied, in order:

    1. Leading and trailing whitespace is stripped.
    2. Runs of spaces/tabs are collapsed into a single space.
    3. An ASCII lowercase letter that follows ``.``, ``!`` or ``?`` plus
       whitespace is upper-cased.

    Line breaks are preserved throughout, including a line break that sits
    between sentence-ending punctuation and the letter being capitalized.

    Args:
    ----
        text: The text string to normalize.

    Returns:
    -------
        The cleaned text. An empty, falsy, or whitespace-only input yields
        an empty string (no error is raised).

    Example:
    -------
        >>> normalize_text("  Hello   world  ")
        'Hello world'
        >>> normalize_text("First sentence. second sentence")
        'First sentence. Second sentence'
        >>> normalize_text("First paragraph.\\n\\nsecond paragraph")
        'First paragraph.\\n\\nSecond paragraph'

    Note:
    ----
        - Single spaces inside broken words like "wo rd" are left alone; this
          function performs no word-level corrections.
        - Capitalization is a purely positional heuristic: any ASCII lowercase
          letter after ``[.!?]`` + whitespace is upper-cased, so abbreviations
          such as "e.g. the" become "e.g. The".

    """
    # Falsy input (empty string, or None passed by a lenient caller) -> ""
    if not text:
        return ""

    stripped = text.strip()
    if not stripped:
        return ""

    # Collapse inline spacing only; the pattern cannot match newlines.
    collapsed = _MULTI_SPACE_PATTERN.sub(" ", stripped)

    def _capitalize_after_sentence(match: re.Match[str]) -> str:
        """Upper-case the letter after sentence-ending punctuation."""
        # Everything between the punctuation mark and the letter.
        gap = match.group(0)[1:-1]
        # Keep the gap verbatim when it carries a line break so paragraph
        # structure survives; otherwise normalize it to one space.
        if "\n" not in gap and "\r" not in gap:
            gap = " "
        return f"{match.group(1)}{gap}{match.group(2).upper()}"

    return _SENTENCE_END_PATTERN.sub(_capitalize_after_sentence, collapsed)
| # ============================================================================= | |
| # Protocols | |
| # ============================================================================= | |
@runtime_checkable
class Retriever(Protocol):
    """Protocol for document retrievers.

    This protocol defines the interface that retriever implementations must
    follow. Any class exposing a compatible ``retrieve`` method satisfies it
    structurally (no inheritance required), which allows mock retrievers to
    stand in for real ones in tests.

    Methods:
    -------
        retrieve: Search for relevant chunks given a query string.

    Example:
    -------
        >>> class MockRetriever:
        ...     def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
        ...         return [
        ...             RetrievalResult(
        ...                 chunk_id="test_001",
        ...                 text="Mock result",
        ...                 score=0.9,
        ...                 heading_path=["Test"],
        ...                 source="test.pdf",
        ...                 page=1,
        ...             )
        ...         ]
        >>> retriever: Retriever = MockRetriever()
        >>> isinstance(retriever, Retriever)
        True

    Note:
    ----
        The ``@runtime_checkable`` decorator is what makes ``isinstance()``
        checks against this protocol legal at runtime; without it they raise
        ``TypeError``. The runtime check only verifies that a ``retrieve``
        attribute exists - it does not validate the signature or return type.

    """

    def retrieve(self, query: str, top_k: int = 6) -> list[RetrievalResult]:
        """Retrieve relevant chunks for a given query.

        This is the contract implementations are expected to honour; the
        protocol itself provides no behavior.

        Args:
        ----
            query: The search query string, e.g. a natural language question
                or keyword phrase.
            top_k: Maximum number of results to return. Defaults to 6.
                Must be a positive integer.

        Returns:
        -------
            List of RetrievalResult objects sorted by score in descending
            order (highest relevance first). The list may contain fewer
            than top_k results if not enough relevant chunks are found.

        Raises:
        ------
            ValueError: If query is empty or top_k is not positive.
            RuntimeError: If retrieval fails due to index issues.

        Example:
        -------
            >>> results = retriever.retrieve("What is PMV?", top_k=5)
            >>> len(results) <= 5
            True
            >>> all(0.0 <= r.score <= 1.0 for r in results)
            True

        """
        ...
| # ============================================================================= | |
| # Data Models | |
| # ============================================================================= | |
class RetrievalResult(BaseModel):
    """Represents a single retrieved chunk with its relevance score.

    A RetrievalResult bundles the chunk content, the metadata needed for
    citation, and a relevance score indicating how well the chunk matches
    the query.

    The model is frozen, so fields cannot be reassigned after construction.
    NOTE(review): ``heading_path`` is stored as a ``list``, so hashing an
    instance is expected to fail even with ``frozen=True`` - confirm before
    using results as dict keys or set members.

    Attributes:
    ----------
        chunk_id : str
            Unique identifier for the chunk within the corpus.
            Surrounding whitespace is stripped; must be non-empty.
        text : str
            The text content of the chunk, normalized with normalize_text()
            before validation; must be non-empty after normalization.
        score : float
            Relevance score between 0.0 and 1.0 inclusive.
        heading_path : list[str]
            Hierarchical path of headings providing context for the chunk.
            Example: ["Chapter 1", "Section 1.1", "PMV Model"]
            Empty list if no heading hierarchy is available.
        source : str
            Source document name or path for citation purposes.
            Surrounding whitespace is stripped; must be non-empty.
            Example: "ashrae_55.pdf"
        page : int
            Page number in the source document (1-indexed, so >= 1).

    Example:
    -------
        >>> result = RetrievalResult(
        ...     chunk_id="ashrae55_001",
        ...     text="The PMV model predicts thermal sensation...",
        ...     score=0.85,
        ...     heading_path=["Thermal Comfort", "PMV Model"],
        ...     source="ashrae_55.pdf",
        ...     page=5,
        ... )
        >>> result.score
        0.85
        >>> result.heading_path
        ['Thermal Comfort', 'PMV Model']

    Note:
    ----
        All three validators below run in ``mode="before"``, i.e. on the raw
        input prior to Pydantic's own type/constraint checks, so the
        ``min_length`` constraints apply to the cleaned values.

    """

    # -------------------------------------------------------------------------
    # Model Configuration
    # -------------------------------------------------------------------------
    model_config = ConfigDict(
        # Forbid extra fields to catch typos and ensure data integrity
        extra="forbid",
        # Make the model immutable (attribute assignment raises)
        frozen=True,
        # Allow population by field name or alias
        populate_by_name=True,
        # Validate default values during model creation
        validate_default=True,
        # Enable JSON schema generation with examples
        json_schema_extra={
            "examples": [
                {
                    "chunk_id": "ashrae55_001",
                    "text": "The PMV model predicts thermal sensation based on...",
                    "score": 0.85,
                    "heading_path": ["Thermal Comfort", "PMV Model"],
                    "source": "ashrae_55.pdf",
                    "page": 5,
                },
                {
                    "chunk_id": "iso7730_042",
                    "text": "The PPD index represents the percentage of...",
                    "score": 0.72,
                    "heading_path": ["ISO 7730", "PPD Calculation"],
                    "source": "iso_7730.pdf",
                    "page": 12,
                },
            ]
        },
    )

    # -------------------------------------------------------------------------
    # Fields
    # -------------------------------------------------------------------------
    chunk_id: str = Field(
        ...,  # Required field (no default)
        min_length=1,  # Must not be empty
        description="Unique identifier for the chunk within the corpus",
        examples=["ashrae55_001", "iso7730_042", "guide_chapter2_015"],
    )
    text: str = Field(
        ...,  # Required field
        min_length=1,  # Must not be empty after normalization
        description="The text content of the chunk (normalized)",
        examples=["The PMV model predicts thermal sensation based on..."],
    )
    score: float = Field(
        ...,  # Required field
        ge=0.0,  # Minimum score is 0.0
        le=1.0,  # Maximum score is 1.0
        description="Relevance score between 0.0 and 1.0",
        examples=[0.85, 0.72, 0.65],
    )
    heading_path: list[str] = Field(
        default_factory=list,
        description="Hierarchical path of headings providing context",
        examples=[["Chapter 1", "Section 1.1"], ["Thermal Comfort", "PMV Model"]],
    )
    source: str = Field(
        ...,  # Required field
        min_length=1,  # Must not be empty
        description="Source document name or path",
        examples=["ashrae_55.pdf", "iso_7730.pdf", "pythermalcomfort_guide.pdf"],
    )
    page: int = Field(
        ...,  # Required field
        ge=1,  # Pages are 1-indexed
        description="Page number in the source document (1-indexed)",
        examples=[1, 5, 42],
    )

    # -------------------------------------------------------------------------
    # Validators
    # -------------------------------------------------------------------------
    @field_validator("text", mode="before")
    @classmethod
    def _normalize_text_field(cls, value: object) -> str:
        """Normalize the text field using the normalize_text function.

        Registered as a ``mode="before"`` validator so it sees the raw input
        and the ``min_length`` check applies to the normalized string.

        Args:
        ----
            value: The raw input value; it is converted with ``str()``.

        Returns:
        -------
            Normalized string content.

        Raises:
        ------
            ValueError: If value is None or empty after normalization.

        """
        if value is None:
            msg = "text cannot be None"
            raise ValueError(msg)

        # Convert to string and normalize
        text = normalize_text(str(value))
        if not text:
            msg = "text cannot be empty after normalization"
            raise ValueError(msg)
        return text

    @field_validator("chunk_id", "source", mode="before")
    @classmethod
    def _strip_string_fields(cls, value: object) -> str:
        """Strip whitespace from the string identifier fields.

        Applied to ``chunk_id`` and ``source`` before standard validation.

        Args:
        ----
            value: The raw input value; it is converted with ``str()``.

        Returns:
        -------
            Stripped string value.

        Raises:
        ------
            ValueError: If value is None or empty after stripping.

        """
        if value is None:
            msg = "Field cannot be None"
            raise ValueError(msg)

        result = str(value).strip()
        if not result:
            msg = "Field cannot be empty"
            raise ValueError(msg)
        return result

    @field_validator("heading_path", mode="before")
    @classmethod
    def _ensure_heading_list(cls, value: object) -> list[str]:
        """Ensure heading_path is always a list of non-empty strings.

        Accepts None (-> empty list), a single string (-> one-element list),
        any iterable (each item stringified and stripped), or a non-iterable
        scalar (stringified and wrapped). Items that are empty after
        stripping are dropped.

        Args:
        ----
            value: The raw input value to coerce.

        Returns:
        -------
            List of heading strings (may be empty).

        """
        if value is None:
            return []

        if isinstance(value, str):
            # A bare string is one heading, not an iterable of characters
            candidates: list[object] = [value]
        else:
            try:
                candidates = list(iter(value))  # type: ignore[call-overload]
            except TypeError:
                # Not iterable: treat the value as a single heading
                candidates = [value]

        stripped = (str(item).strip() for item in candidates)
        return [heading for heading in stripped if heading]
class RetrievalQuery(BaseModel):
    """Represents a search query with configuration options.

    A RetrievalQuery encapsulates the parameters of one retrieval
    operation: the query string plus the options that control how the
    search is performed.

    The model is frozen, so fields cannot be reassigned after construction.

    Attributes:
    ----------
        query : str
            The search query string, normalized with normalize_text()
            before validation; must be non-empty after normalization.
        top_k : int
            Number of results to return. Defaults to 6.
            Must be a positive integer.
        use_reranker : bool
            Whether to use a reranker for better result ranking.
            Defaults to False.

    Example:
    -------
        >>> query = RetrievalQuery(query="What is PMV?")
        >>> query.query
        'What is PMV?'
        >>> query.top_k
        6
        >>> query = RetrievalQuery(query="thermal comfort", top_k=10, use_reranker=True)
        >>> query.use_reranker
        True

    Note:
    ----
        The query validator runs in ``mode="before"``, i.e. on the raw input
        prior to Pydantic's own checks, so ``min_length`` applies to the
        normalized query string.

    """

    # -------------------------------------------------------------------------
    # Model Configuration
    # -------------------------------------------------------------------------
    model_config = ConfigDict(
        # Forbid extra fields to catch typos and ensure data integrity
        extra="forbid",
        # Make the model immutable (attribute assignment raises)
        frozen=True,
        # Allow population by field name or alias
        populate_by_name=True,
        # Validate default values during model creation
        validate_default=True,
        # Enable JSON schema generation with examples
        json_schema_extra={
            "examples": [
                {
                    "query": "What is PMV?",
                    "top_k": 6,
                    "use_reranker": False,
                },
                {
                    "query": "How to calculate thermal comfort?",
                    "top_k": 10,
                    "use_reranker": True,
                },
            ]
        },
    )

    # -------------------------------------------------------------------------
    # Fields
    # -------------------------------------------------------------------------
    query: str = Field(
        ...,  # Required field (no default)
        min_length=1,  # Must not be empty after normalization
        description="The search query string (normalized)",
        examples=["What is PMV?", "thermal comfort calculation", "ASHRAE 55 standard"],
    )
    top_k: int = Field(
        default=6,
        gt=0,  # Must be positive (greater than 0)
        description="Number of results to return",
        examples=[5, 6, 10, 20],
    )
    use_reranker: bool = Field(
        default=False,
        description="Whether to use a reranker for better result ranking",
    )

    # -------------------------------------------------------------------------
    # Validators
    # -------------------------------------------------------------------------
    @field_validator("query", mode="before")
    @classmethod
    def _normalize_query_field(cls, value: object) -> str:
        """Normalize the query field using the normalize_text function.

        Registered as a ``mode="before"`` validator so it sees the raw input
        and the ``min_length`` check applies to the normalized string.

        Args:
        ----
            value: The raw input value; it is converted with ``str()``.

        Returns:
        -------
            Normalized query string.

        Raises:
        ------
            ValueError: If value is None or empty after normalization.

        """
        if value is None:
            msg = "query cannot be None"
            raise ValueError(msg)

        # Convert to string and normalize
        query = normalize_text(str(value))
        if not query:
            msg = "query cannot be empty after normalization"
            raise ValueError(msg)
        return query