Spaces:
Sleeping
Sleeping
| # DEPENDENCIES | |
| import re | |
| from enum import Enum | |
| from typing import List | |
| from typing import Dict | |
| from typing import Optional | |
| from collections import defaultdict | |
| from config.models import CitationStyle | |
| from config.models import ChunkWithScore | |
| from config.logging_config import get_logger | |
| from utils.error_handler import handle_errors | |
| from utils.error_handler import CitationFormattingError | |
# Setup Logging: module-level logger obtained from the project's logging configuration;
# CitationFormatter stores a reference to it as `self.logger`
logger = get_logger(__name__)
class CitationFormatter:
    """
    Citation formatting and management: formats citations in generated text according to different
    styles and checks citation consistency and validity

    Inline markers are recognised in the bracketed numeric form `[N]`; the number N is treated as a
    1-based index into the list of sources supplied to each method
    """
    def __init__(self, style: CitationStyle = CitationStyle.NUMERIC):
        """
        Initialize citation formatter

        Arguments:
        ----------
            style { CitationStyle } : Citation style to use
        """
        self.logger           = logger
        self.style            = style

        # Matches markers such as "[3]" and captures the digits
        self.citation_pattern = re.compile(r'\[(\d+)\]')

        # Style configurations: inline marker template, reference-entry template and the string
        # used to join reference entries
        self.style_configs    = {CitationStyle.NUMERIC  : {"inline_format" : "[{number}]", "reference_format" : "[{number}] {source_info}", "separator" : " ",},
                                 CitationStyle.VERBOSE  : {"inline_format" : "[{number}]", "reference_format" : "Citation {number}: {source_info}", "separator" : "\n",},
                                 CitationStyle.MINIMAL  : {"inline_format" : "[{number}]", "reference_format" : "[{number}]", "separator" : " ",},
                                 CitationStyle.ACADEMIC : {"inline_format" : "({number})", "reference_format" : "{number}. {source_info}", "separator" : "\n",},
                                 CitationStyle.LEGAL    : {"inline_format" : "[{number}]", "reference_format" : "[{number}] {source_info}", "separator" : "\n",}
                                }


    def format_citations_in_text(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Format citations in generated text

        Markers whose number has no matching source are left untouched

        Arguments:
        ----------
            text    { str }  : Text containing citation markers

            sources { list } : List of sources for citation mapping

        Returns:
        --------
            { str }          : Text with formatted citations (the input text unchanged when it is
                               empty, has no sources, or contains no markers)

        Raises:
        -------
            CitationFormattingError : If any step of the formatting fails
        """
        if not text or not sources:
            return text

        try:
            # Extract citation numbers from text
            citation_numbers = self._extract_citation_numbers(text = text)

            if not citation_numbers:
                return text

            # Create citation mapping
            citation_map   = self._create_citation_map(sources = sources)

            # Replace citation markers with formatted citations
            formatted_text = self._replace_citation_markers(text         = text,
                                                            citation_map = citation_map,
                                                           )

            self.logger.debug(f"Formatted {len(citation_numbers)} citations in text")

            return formatted_text

        except Exception as e:
            self.logger.error(f"Citation formatting failed: {repr(e)}")
            # Chain the original exception so the root cause stays visible in tracebacks
            raise CitationFormattingError(f"Citation formatting failed: {repr(e)}") from e


    def generate_reference_section(self, sources: List[ChunkWithScore], cited_numbers: List[int]) -> str:
        """
        Generate reference section for cited sources

        NOTE(review): entries are renumbered 1..k in the order given by `cited_numbers`, so they line
        up with the inline markers only when the text uses that same sequential numbering (for
        example after `normalize_citations` with `cited_numbers` sorted ascending) - confirm callers

        Arguments:
        ----------
            sources       { list } : All available sources

            cited_numbers { list } : Numbers of actually cited sources (1-based; out-of-range
                                     numbers are ignored)

        Returns:
        --------
            { str }                : Formatted reference section; empty string when there is nothing
                                     to list or when generation fails (the failure is logged)
        """
        if not sources or not cited_numbers:
            return ""

        try:
            style_config  = self.style_configs[self.style]

            # Get only cited sources
            cited_sources = [sources[num - 1] for num in cited_numbers if (1 <= num <= len(sources))]

            references    = [style_config["reference_format"].format(number      = position,
                                                                     source_info = self._format_source_info(source, position),
                                                                    )
                             for position, source in enumerate(cited_sources, 1)
                            ]

            reference_section = style_config["separator"].join(references)

            # Add section header if appropriate
            if (self.style in [CitationStyle.VERBOSE, CitationStyle.ACADEMIC]):
                reference_section = "References:\n" + reference_section

            self.logger.debug(f"Generated reference section with {len(references)} entries")

            return reference_section

        except Exception as e:
            # Best-effort: a missing reference section must not break the caller
            self.logger.error(f"Reference section generation failed: {repr(e)}")
            return ""


    def _extract_citation_numbers(self, text: str) -> List[int]:
        """
        Extract the distinct citation numbers found in text, sorted ascending
        """
        matches = self.citation_pattern.findall(text)

        # Unique and sorted
        return sorted({int(match) for match in matches})


    def _create_citation_map(self, sources: List[ChunkWithScore]) -> Dict[int, str]:
        """
        Create mapping from citation numbers (1..len(sources)) to formatted inline citations
        """
        inline_format = self.style_configs[self.style]["inline_format"]

        return {number: inline_format.format(number = number) for number in range(1, len(sources) + 1)}


    def _replace_citation_markers(self, text: str, citation_map: Dict[int, str]) -> str:
        """
        Replace citation markers in text; markers without an entry in the map are kept verbatim
        """
        def replacement(match):
            citation_num = int(match.group(1))

            return citation_map.get(citation_num, match.group(0))

        return self.citation_pattern.sub(replacement, text)


    def _format_source_info(self, source: ChunkWithScore, citation_number: int) -> str:
        """
        Format source information based on style

        Arguments:
        ----------
            source          { ChunkWithScore } : Source whose chunk metadata and location are described

            citation_number { int }            : Number shown for the source (used by the MINIMAL style)

        Returns:
        --------
            { str }                            : Space-joined description of the source
        """
        chunk = source.chunk

        if (self.style == CitationStyle.MINIMAL):
            return f"Source {citation_number}"

        # Build source components
        components = list()

        # Document information: a missing or None metadata attribute is treated as "no metadata"
        meta = getattr(chunk, 'metadata', None) or {}

        # Empty values are skipped so no dangling labels such as "Document: " are emitted
        if (('filename' in meta) and meta['filename']):
            components.append(f"Document: {meta['filename']}")

        if (('title' in meta) and meta['title']):
            components.append(f"\"{meta['title']}\"")

        if (('author' in meta) and meta['author']):
            components.append(f"by {meta['author']}")

        # Location information
        location_parts = list()

        if chunk.page_number:
            location_parts.append(f"p. {chunk.page_number}")

        if chunk.section_title:
            location_parts.append(f"Section: {chunk.section_title}")

        if location_parts:
            components.append("(" + ", ".join(location_parts) + ")")

        # Relevance score (for verbose styles); a None score is skipped instead of raising
        score = source.score

        if ((self.style in [CitationStyle.VERBOSE, CitationStyle.ACADEMIC]) and (score is not None) and (score > 0)):
            components.append(f"[relevance: {score:.3f}]")

        return " ".join(components)


    def validate_citations(self, text: str, sources: List[ChunkWithScore]) -> tuple[bool, List[int]]:
        """
        Validate citations in text

        A citation is invalid when its number is below 1 or above the number of sources

        Arguments:
        ----------
            text    { str }  : Text to validate

            sources { list } : Available sources

        Returns:
        --------
            { tuple }        : (is_valid, invalid_citations)
        """
        citation_numbers  = self._extract_citation_numbers(text = text)

        # An empty or None source list means every citation is out of range
        source_count      = len(sources) if sources else 0

        invalid_citations = [number for number in citation_numbers if not (1 <= number <= source_count)]
        is_valid          = not invalid_citations

        if not is_valid:
            self.logger.warning(f"Invalid citations found: {invalid_citations}")

        return is_valid, invalid_citations


    def normalize_citations(self, text: str, sources: List[ChunkWithScore]) -> str:
        """
        Normalize citations to ensure sequential numbering

        Valid citation numbers are renumbered 1..k in ascending order of their original value;
        out-of-range numbers keep their value. Every marker is re-emitted in the inline format of
        the current style

        Arguments:
        ----------
            text    { str }  : Text with citations

            sources { list } : Available sources

        Returns:
        --------
            { str }          : Text with normalized citations
        """
        citation_numbers = self._extract_citation_numbers(text = text)

        if not citation_numbers:
            return text

        source_count     = len(sources) if sources else 0

        # Create mapping from old to new numbers; only valid citations advance the counter, so
        # invalid ones cannot leave gaps in the new sequence
        valid_numbers    = [number for number in citation_numbers if (1 <= number <= source_count)]
        citation_mapping = {old_num: new_num for new_num, old_num in enumerate(valid_numbers, 1)}

        inline_format    = self.style_configs[self.style]["inline_format"]

        # Replace citations
        def normalize_replacement(match):
            old_num = int(match.group(1))
            new_num = citation_mapping.get(old_num, old_num)

            return inline_format.format(number = new_num)

        normalized_text = self.citation_pattern.sub(normalize_replacement, text)

        if citation_mapping:
            self.logger.info(f"Normalized citations: {list(citation_mapping.keys())} -> {list(citation_mapping.values())}")

        return normalized_text


    def get_citation_statistics(self, text: str, sources: List[ChunkWithScore]) -> Dict:
        """
        Get citation statistics

        Arguments:
        ----------
            text    { str }  : Text with citations

            sources { list } : Available sources

        Returns:
        --------
            { dict }         : Citation statistics with keys:
                               - total_citations      : number of citation markers, repeats included
                               - unique_citations     : number of distinct citation numbers
                               - sources_used         : number of distinct document ids cited
                               - citations_per_source : document id -> number of markers citing it
                               - citation_density     : markers per whitespace-separated word
        """
        # Every marker occurrence is counted here; the de-duplicated list from
        # _extract_citation_numbers would make total and unique counts always identical
        all_citations = [int(match) for match in self.citation_pattern.findall(text)]

        if not all_citations:
            return {"total_citations"      : 0,
                    "unique_citations"     : 0,
                    "sources_used"         : 0,
                    "citations_per_source" : {},
                    "citation_density"     : 0.0,
                   }

        source_count = len(sources) if sources else 0

        # Calculate distribution (out-of-range citations are not attributed to any document)
        source_usage = defaultdict(int)

        for number in all_citations:
            if (1 <= number <= source_count):
                doc_id                = sources[number - 1].chunk.document_id
                source_usage[doc_id] += 1

        return {"total_citations"      : len(all_citations),
                "unique_citations"     : len(set(all_citations)),
                "sources_used"         : len(source_usage),
                "citations_per_source" : dict(source_usage),
                "citation_density"     : len(all_citations) / max(1, len(text.split())),
               }


    def set_style(self, style: CitationStyle):
        """
        Set citation style

        Arguments:
        ----------
            style { CitationStyle } : New citation style

        Raises:
        -------
            CitationFormattingError : If no configuration exists for the requested style
        """
        if (style not in self.style_configs):
            raise CitationFormattingError(f"Unsupported citation style: {style}")

        old_style  = self.style
        self.style = style

        self.logger.info(f"Citation style changed: {old_style} -> {style}")
# Global citation formatter instance (created on first request)
_citation_formatter = None


def get_citation_formatter() -> CitationFormatter:
    """
    Return the shared CitationFormatter, building it with default arguments on first use

    Returns:
    --------
        { CitationFormatter } : The module-wide CitationFormatter instance
    """
    global _citation_formatter

    # Fast path: the shared instance already exists
    if _citation_formatter is not None:
        return _citation_formatter

    _citation_formatter = CitationFormatter()

    return _citation_formatter
def format_citations(text: str, sources: List[ChunkWithScore], style: Optional[CitationStyle] = None) -> str:
    """
    Convenience function for citation formatting

    Arguments:
    ----------
        text    { str }           : Text containing citations

        sources { list }          : List of sources

        style   { CitationStyle } : Citation style to use for this call only; when None the
                                    shared formatter and its current style are used

    Returns:
    --------
        { str }                   : Formatted text

    Raises:
    -------
        CitationFormattingError   : If the style is unsupported or formatting fails
    """
    if style is None:
        formatter = get_citation_formatter()

    else:
        # A per-call style must not leak into the shared singleton (it would silently change the
        # style seen by every later caller), so a throwaway formatter is used; set_style keeps
        # the unsupported-style validation
        formatter = CitationFormatter()
        formatter.set_style(style)

    return formatter.format_citations_in_text(text, sources)