Spaces:
Running
Running
| # DEPENDENCIES | |
| import re | |
| from typing import List | |
| from typing import Dict | |
| from typing import Tuple | |
| from typing import Optional | |
| from collections import defaultdict | |
| from config.models import DocumentChunk | |
| from config.models import ChunkWithScore | |
| from config.logging_config import get_logger | |
| from utils.error_handler import CitationError | |
| from utils.error_handler import handle_errors | |
# Setup Logging: module-level logger, stored on each CitationTracker instance in __init__
logger = get_logger(__name__)
class CitationTracker:
    """
    Citation tracking and management: tracks numeric source citations (markers such as "[3]", 1-based
    indexes into a list of sources) in generated text and provides citation formatting and validation
    """
    # Characters outside this whitelist (word chars, whitespace, '.', '-') are replaced by '_' before metadata
    # is embedded in a citation line, so values such as "[3]" can never be mistaken for citation markers
    _UNSAFE_CHARS              = re.compile(r'[^\w\s\.\-]')

    # A citation marker together with the spaces/tabs directly before it; used when dropping invalid markers
    # so that "claim [7]." becomes "claim." instead of "claim ."
    _MARKER_WITH_LEADING_SPACE = re.compile(r'[ \t]*\[(\d+)\]')

    def __init__(self):
        """
        Initialize citation tracker
        """
        self.logger           = logger
        # Matches citation markers like "[3]"; group 1 holds the citation number
        self.citation_pattern = re.compile(r'\[(\d+)\]')

    def _sanitize_field(self, value: object, max_length: int) -> str:
        """
        Make a metadata value safe to embed in a citation line

        Arguments:
        ----------
            value      { object } : Raw value (converted with str())

            max_length { int }    : Maximum length; longer values are cut to max_length - 3 chars plus "..."

        Returns:
        --------
            { str }               : Sanitized, length-limited text
        """
        text = self._UNSAFE_CHARS.sub('_', str(value))

        if (len(text) > max_length):
            text = text[:max_length - 3] + "..."

        return text

    def _find_citation_numbers(self, text: str) -> List[int]:
        """
        Find every citation number in text, in order of appearance (duplicates kept)

        Arguments:
        ----------
            text { str } : Text containing citations

        Returns:
        --------
            { list }     : One citation number per marker occurrence
        """
        return [int(match) for match in self.citation_pattern.findall(text)]

    def extract_citations(self, text: str) -> List[int]:
        """
        Extract citation numbers from text

        Arguments:
        ----------
            text { str } : Text containing citations

        Returns:
        --------
            { list }     : Sorted list of the distinct citation numbers found in text ([] on empty input or error)
        """
        if not text:
            return []

        try:
            # Remove duplicates and sort
            unique_citations = sorted(set(self._find_citation_numbers(text)))

            self.logger.debug(f"Extracted {len(unique_citations)} citations from text")

            return unique_citations

        except Exception as e:
            self.logger.error(f"Citation extraction failed: {repr(e)}")
            return []

    def validate_citations(self, text: str, sources: List[ChunkWithScore]) -> Tuple[bool, List[int]]:
        """
        Validate that all citations in text reference existing sources

        Arguments:
        ----------
            text    { str }  : Text containing citations

            sources { list } : List of available sources; valid citation numbers are 1..len(sources)

        Returns:
        --------
            { Tuple[bool, List[int]] } : (is_valid, invalid_citations)
        """
        citation_numbers = self.extract_citations(text = text)

        if not citation_numbers:
            return True, []

        # Check if all citation numbers are within valid range
        max_valid         = len(sources)
        invalid_citations = [num for num in citation_numbers if (num < 1) or (num > max_valid)]

        if invalid_citations:
            self.logger.warning(f"Invalid citations found: {invalid_citations}. Valid range: 1-{max_valid}")
            return False, invalid_citations

        return True, []

    def format_citations(self, sources: List[ChunkWithScore], style: str = "numeric") -> str:
        """
        Format citations as reference list

        Arguments:
        ----------
            sources { list } : List of sources to format (source i is rendered as citation i, 1-based)

            style   { str }  : Citation style ('numeric', 'verbose'); any other value falls back to numeric

        Returns:
        --------
            { str }          : Formatted citation text, one citation per line ("" if sources is empty or
                               formatting fails; failures are logged)
        """
        if not sources:
            return ""

        try:
            citations = list()

            for i, source in enumerate(sources, 1):
                if (style == "verbose"):
                    citation = self._format_verbose_citation(source = source,
                                                             number = i,
                                                            )
                else:
                    citation = self._format_numeric_citation(source = source,
                                                             number = i,
                                                            )
                citations.append(citation)

            citation_text = "\n".join(citations)

            self.logger.debug(f"Formatted {len(citations)} citations in {style} style")

            return citation_text

        except Exception as e:
            self.logger.error(f"Citation formatting failed: {repr(e)}")
            return ""

    def _format_numeric_citation(self, source: ChunkWithScore, number: int) -> str:
        """
        Format citation in numeric style with sanitization

        Arguments:
        ----------
            source { ChunkWithScore } : Source to format

            number { int }            : Citation number

        Returns:
        --------
            { str }                   : Formatted citation, e.g. "[1] Source: a.pdf Page 3 (Relevance: 0.87)"
        """
        chunk = source.chunk
        parts = [f"[{number}]"]

        # Missing, None or empty metadata all mean "no metadata"
        meta  = getattr(chunk, 'metadata', None) or {}

        if ('filename' in meta):
            parts.append(f"Source: {self._sanitize_field(meta['filename'], 50)}")

        if chunk.page_number:
            parts.append(f"Page {chunk.page_number}")

        if chunk.section_title:
            parts.append(f"Section: {self._sanitize_field(chunk.section_title, 40)}")

        # Add relevance score if available
        if (source.score > 0):
            parts.append(f"(Relevance: {source.score:.2f})")

        return " ".join(parts)

    def _format_verbose_citation(self, source: ChunkWithScore, number: int) -> str:
        """
        Format citation in verbose style with sanitization

        Arguments:
        ----------
            source { ChunkWithScore } : Source to format

            number { int }            : Citation number

        Returns:
        --------
            { str }                   : Formatted citation
        """
        chunk = source.chunk
        parts = [f"Citation {number}:"]

        # Guard against metadata being None: a bare `'filename' in None` raises TypeError, which made
        # format_citations() drop the ENTIRE reference list for a single chunk without metadata
        meta  = getattr(chunk, 'metadata', None) or {}

        # Document information with sanitization
        if ('filename' in meta):
            parts.append(f"Document: {self._sanitize_field(meta['filename'], 50)}")

        if ('title' in meta):
            parts.append(f"Title: {self._sanitize_field(meta['title'], 60)}")

        if ('author' in meta):
            parts.append(f"Author: {self._sanitize_field(meta['author'], 40)}")

        # Location information
        location_parts = list()

        if chunk.page_number:
            location_parts.append(f"page {chunk.page_number}")

        if chunk.section_title:
            location_parts.append(f"section '{self._sanitize_field(chunk.section_title, 40)}'")

        if location_parts:
            parts.append("(" + ", ".join(location_parts) + ")")

        # Relevance information
        if (source.score > 0):
            parts.append(f"[Relevance score: {source.score:.3f}]")

        return " ".join(parts)

    def generate_citation_map(self, sources: List[ChunkWithScore]) -> Dict[int, Dict]:
        """
        Generate mapping from citation numbers to source details

        Arguments:
        ----------
            sources { list } : List of sources

        Returns:
        --------
            { dict }         : Dictionary mapping citation numbers (1-based) to source details; text_preview is
                               the first 200 characters of the chunk text, followed by "..." when truncated
        """
        citation_map = dict()

        for i, source in enumerate(sources, 1):
            chunk           = source.chunk
            citation_map[i] = {'chunk_id'      : chunk.chunk_id,
                               'document_id'   : chunk.document_id,
                               'score'         : source.score,
                               'text_preview'  : chunk.text[:200] + "..." if (len(chunk.text) > 200) else chunk.text,
                               'metadata'      : getattr(chunk, 'metadata', {}),
                               'page_number'   : chunk.page_number,
                               'section_title' : chunk.section_title,
                              }

        return citation_map

    def replace_citation_markers(self, text: str, citation_map: Dict[int, str]) -> str:
        """
        Replace citation markers with formatted citations

        Arguments:
        ----------
            text         { str }  : Text containing citation markers

            citation_map { dict } : Mapping of citation numbers to replacement strings; markers whose number is
                                    not in the map are left unchanged

        Returns:
        --------
            { str }               : Text with replaced citations (the original text if replacement fails)
        """
        def replacement(match):
            try:
                citation_num     = int(match.group(1))
                # Look up the replacement text; unknown numbers keep their original marker
                replacement_text = citation_map.get(citation_num, match.group(0))
                return str(replacement_text)

            except (ValueError, IndexError):
                # Return original match if parsing fails
                return match.group(0)

        try:
            return self.citation_pattern.sub(replacement, text)

        except Exception as e:
            self.logger.error(f"Citation replacement failed: {repr(e)}")
            # Return original text on error
            return text

    def get_citation_statistics(self, text: str, sources: List[ChunkWithScore]) -> Dict:
        """
        Get statistics about citations in text

        Arguments:
        ----------
            text    { str }  : Text containing citations

            sources { list } : List of sources (citation n refers to sources[n - 1])

        Returns:
        --------
            { dict }         : Citation statistics:
                               - total_citations       : number of citation markers (repeats counted)
                               - unique_citations      : number of distinct citation numbers
                               - citation_distribution : document_id -> marker count, for in-range citations
                               - citations_per_source  : citation number -> marker count
        """
        # Count every occurrence: extract_citations() deduplicates, which made total == unique and every
        # per-source count 1
        try:
            all_numbers = self._find_citation_numbers(text) if text else []

        except Exception as e:
            self.logger.error(f"Citation extraction failed: {repr(e)}")
            all_numbers = []

        if not all_numbers:
            return {"total_citations"      : 0,
                    "unique_citations"     : 0,
                    "citation_distribution": {},
                    "citations_per_source" : {},
                   }

        # Single pass instead of list.count() per distinct number (quadratic)
        per_number = defaultdict(int)
        for num in all_numbers:
            per_number[num] += 1

        # Calculate citation distribution over documents, ignoring out-of-range citations
        num_sources     = len(sources)
        citation_counts = defaultdict(int)

        for num, count in per_number.items():
            if (1 <= num <= num_sources):
                citation_counts[sources[num - 1].chunk.document_id] += count

        return {"total_citations"      : len(all_numbers),
                "unique_citations"     : len(per_number),
                "citation_distribution": dict(citation_counts),
                "citations_per_source" : dict(per_number),
               }

    def ensure_citation_consistency(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Ensure every citation in text refers to an existing source

        - Text whose citations are all in 1..len(sources) is returned unchanged
        - If the citations look 0-based (smallest is 0 and largest is below len(sources)), every marker is
          shifted up by one
        - Otherwise out-of-range markers are removed (with the spaces/tabs directly before them) and valid
          markers are kept exactly as written

        Valid citations are deliberately NOT renumbered to be sequential: the source list is not reordered
        here, so shifting their numbers would make them point at a different source than intended

        Arguments:
        ----------
            text    { str }  : Text containing citations

            sources { list } : List of sources

        Returns:
        --------
            { str }          : Text with consistent citations
        """
        is_valid, invalid_citations = self.validate_citations(text, sources)

        if is_valid:
            return text

        self.logger.warning("Invalid citations found, attempting to fix consistency")

        # Non-empty here: an invalid result implies at least one citation was found (sorted ascending)
        current_citations = self.extract_citations(text = text)
        num_sources       = len(sources)

        if (current_citations[0] == 0) and (current_citations[-1] < num_sources):
            # Consistent 0-based numbering: shift everything to 1-based
            fixed_text = self.citation_pattern.sub(lambda match: f"[{int(match.group(1)) + 1}]", text)

            self.logger.info(f"Fixed citation consistency: shifted 0-based citations {current_citations} to 1-based")

            return fixed_text

        def keep_valid_marker(match):
            number = int(match.group(1))
            return match.group(0) if (1 <= number <= num_sources) else ""

        fixed_text = self._MARKER_WITH_LEADING_SPACE.sub(keep_valid_marker, text)

        self.logger.info(f"Fixed citation consistency: removed invalid citations {invalid_citations} (valid range: 1-{num_sources})")

        return fixed_text
# Lazily created, process-wide CitationTracker shared by get_citation_tracker() callers
_citation_tracker = None


def get_citation_tracker() -> CitationTracker:
    """
    Return the shared CitationTracker, creating it on first use

    Returns:
    --------
        { CitationTracker } : The module-level CitationTracker instance (same object on every call)
    """
    global _citation_tracker

    if _citation_tracker is not None:
        return _citation_tracker

    _citation_tracker = CitationTracker()

    return _citation_tracker
def extract_and_validate_citations(text: str, sources: List[ChunkWithScore]) -> Tuple[List[int], bool]:
    """
    Extract the citations in text and check them against sources in a single call

    Arguments:
    ----------
        text    { str }  : Text containing citations

        sources { list } : List of sources the citations should refer to

    Returns:
    --------
        { Tuple[List[int], bool] } : (citation numbers from CitationTracker.extract_citations,
                                      validity flag from CitationTracker.validate_citations)
    """
    shared_tracker = get_citation_tracker()

    found_numbers  = shared_tracker.extract_citations(text = text)
    validation     = shared_tracker.validate_citations(text    = text,
                                                       sources = sources,
                                                      )
    all_valid      = validation[0]

    return found_numbers, all_valid