Spaces:
Runtime error
Runtime error
| """Utilities for tracking and formatting source citations.""" | |
| from typing import List, Dict, Any | |
| from langchain_core.documents import Document | |
class CitationTracker:
    """Tracks retrieved documents and assigns them citation numbers.

    Each distinct document receives a 1-indexed source ID the first time it is
    added; adding an equivalent document again returns the ID it already has.
    """

    def __init__(self):
        # Documents in first-seen order; self.sources[i] is source ID i + 1.
        self.sources: List[Document] = []
        # Deduplication key (built by _create_doc_key) -> assigned source ID.
        self.source_map: Dict[str, int] = {}

    def add_document(self, doc: Document) -> int:
        """
        Add a document and return its source ID.

        Args:
            doc: LangChain Document with metadata

        Returns:
            Source ID (1-indexed). If a document with the same deduplication
            key was added earlier, its existing ID is returned and nothing new
            is stored.
        """
        doc_key = self._create_doc_key(doc)

        existing_id = self.source_map.get(doc_key)
        if existing_id is not None:
            return existing_id

        self.sources.append(doc)
        source_id = len(self.sources)
        self.source_map[doc_key] = source_id
        return source_id

    def _create_doc_key(self, doc: Document) -> str:
        """
        Create a unique key for document deduplication.

        When the metadata has both ``filename`` and ``chunk_id`` those two
        values identify the chunk. If either is missing, the page content is
        added to the key too; otherwise every chunk lacking that metadata would
        share one key and be merged into a single source.

        The key is the ``repr`` of a tuple rather than a joined string so that
        values containing a separator cannot collide (e.g. ``("a_1", "2")``
        vs ``("a", "1_2")``).
        """
        metadata = doc.metadata or {}
        # str() so that e.g. chunk_id 3 and "3" are treated as the same chunk.
        key_parts = tuple(
            None if value is None else str(value)
            for value in (metadata.get('filename'), metadata.get('chunk_id'))
        )
        if None in key_parts:
            key_parts += (doc.page_content,)
        return repr(key_parts)

    def format_context_with_citations(self, documents: List[Document]) -> str:
        """
        Format documents into context string with source markers.

        Each document is registered via ``add_document`` (so repeated documents
        reuse their existing ID) and rendered as ``[Source N] <content>``.

        Args:
            documents: List of LangChain Documents

        Returns:
            Formatted context string with [Source N] markers, one document per
            paragraph (separated by a blank line); "" for an empty list.
        """
        context_parts = []
        for doc in documents:
            source_id = self.add_document(doc)
            context_parts.append(f"[Source {source_id}] {doc.page_content}")
        return "\n\n".join(context_parts)

    def get_sources_list(self, preview_length: int = 200) -> List[Dict[str, Any]]:
        """
        Get formatted list of all sources.

        Args:
            preview_length: Maximum number of page-content characters kept in
                ``text_preview``; longer content is truncated and suffixed
                with "...". Defaults to 200.

        Returns:
            List of source dictionaries with metadata, ordered by source ID.
            ``chunk_id`` is a str (or None when absent) and
            ``similarity_score`` is a float (0.0 when absent or None).
        """
        sources_list = []
        for idx, doc in enumerate(self.sources, start=1):
            metadata = doc.metadata or {}
            content = doc.page_content

            text_preview = content[:preview_length]
            if len(content) > preview_length:
                text_preview += "..."

            # Normalize chunk_id to str so the field has one consistent type.
            chunk_id = metadata.get('chunk_id')
            if chunk_id is not None:
                chunk_id = str(chunk_id)

            # A key present with value None would make float() raise TypeError.
            raw_score = metadata.get('similarity_score')
            similarity_score = 0.0 if raw_score is None else float(raw_score)

            sources_list.append({
                "source_id": idx,
                "filename": metadata.get('filename', 'unknown'),
                "doc_type": metadata.get('doc_type', 'unknown'),
                "ticker": metadata.get('ticker'),
                "similarity_score": similarity_score,
                "chunk_id": chunk_id,
                "text_preview": text_preview,
            })
        return sources_list

    def clear(self):
        """Clear all tracked sources so numbering restarts at 1."""
        self.sources.clear()
        self.source_map.clear()
def extract_citations_from_answer(answer: str) -> List[int]:
    """
    Collect the distinct source numbers cited in an answer.

    Args:
        answer: Generated answer text that may contain [Source N] markers

    Returns:
        Ascending list of each distinct N that appears as [Source N]
    """
    import re

    found_ids = {
        int(match.group(1))
        for match in re.finditer(r'\[Source (\d+)\]', answer)
    }
    return sorted(found_ids)