adeyemi001's picture
Upload project excluding env and notebook
d8ba418 verified
"""Utilities for tracking and formatting source citations."""
import hashlib
from typing import List, Dict, Any

from langchain_core.documents import Document
class CitationTracker:
    """Tracks sources and generates citation references.

    Documents are registered in first-seen order and receive 1-indexed
    source IDs. A document that maps to an already-registered key reuses
    the existing ID instead of being stored again.
    """

    # Maximum number of characters of page content shown in a source preview.
    PREVIEW_CHARS = 200

    def __init__(self):
        # Registered documents; position i holds the document with ID i + 1.
        self.sources: List["Document"] = []
        # Deduplication key -> source ID.
        self.source_map: Dict[str, int] = {}

    def add_document(self, doc: "Document") -> int:
        """
        Add a document and return its source ID.

        Args:
            doc: LangChain Document with metadata

        Returns:
            Source ID (1-indexed). If a document with the same key was
            added before, the previously assigned ID is returned and the
            document is not stored a second time.
        """
        doc_key = self._create_doc_key(doc)

        existing_id = self.source_map.get(doc_key)
        if existing_id is not None:
            return existing_id

        # IDs follow insertion order, starting at 1.
        source_id = len(self.sources) + 1
        self.sources.append(doc)
        self.source_map[doc_key] = source_id
        return source_id

    def _create_doc_key(self, doc: "Document") -> str:
        """Create unique key for document deduplication.

        When both ``filename`` and ``chunk_id`` are present in the metadata
        the key is ``"{filename}_{chunk_id}"``. When either is missing, a
        SHA-256 digest of the page content is appended so that distinct
        documents without identifying metadata do not all collapse onto the
        same key (and therefore the same source ID).
        """
        metadata = doc.metadata
        filename = metadata.get('filename')
        chunk_id = metadata.get('chunk_id')

        if filename is not None and chunk_id is not None:
            return f"{filename}_{chunk_id}"

        # Metadata alone cannot identify this document; fall back to content.
        name_part = 'unknown' if filename is None else filename
        chunk_part = 'unknown' if chunk_id is None else chunk_id
        digest = hashlib.sha256(
            doc.page_content.encode("utf-8", errors="replace")
        ).hexdigest()
        return f"{name_part}_{chunk_part}_{digest}"

    def format_context_with_citations(self, documents: List["Document"]) -> str:
        """
        Format documents into context string with source markers.

        Every document is registered with the tracker (via ``add_document``)
        as a side effect, so repeated documents share a marker.

        Args:
            documents: List of LangChain Documents

        Returns:
            Formatted context string with [Source N] markers, one entry per
            input document, separated by blank lines. Empty string for an
            empty list.
        """
        return "\n\n".join(
            f"[Source {self.add_document(doc)}] {doc.page_content}"
            for doc in documents
        )

    def get_sources_list(self) -> List[Dict[str, Any]]:
        """
        Get formatted list of all sources.

        Returns:
            List of source dictionaries, in source-ID order, each with the
            keys ``source_id``, ``filename``, ``doc_type``, ``ticker``,
            ``similarity_score``, ``chunk_id`` and ``text_preview``.
        """
        sources_list = []
        for idx, doc in enumerate(self.sources, start=1):
            metadata = doc.metadata

            # Preview is the leading slice of the content, with an ellipsis
            # only when something was actually cut off.
            content = doc.page_content
            text_preview = content[:self.PREVIEW_CHARS]
            if len(content) > self.PREVIEW_CHARS:
                text_preview += "..."

            # Report chunk_id as a string whenever one exists, whatever type
            # it was stored as; a missing chunk_id stays None.
            chunk_id = metadata.get('chunk_id')
            if chunk_id is not None:
                chunk_id = str(chunk_id)

            # A score stored explicitly as None would make float() raise;
            # treat it the same as a missing score.
            raw_score = metadata.get('similarity_score')
            similarity_score = 0.0 if raw_score is None else float(raw_score)

            sources_list.append({
                "source_id": idx,
                "filename": metadata.get('filename', 'unknown'),
                "doc_type": metadata.get('doc_type', 'unknown'),
                "ticker": metadata.get('ticker'),
                "similarity_score": similarity_score,
                "chunk_id": chunk_id,
                "text_preview": text_preview,
            })
        return sources_list

    def clear(self):
        """Clear all tracked sources."""
        self.sources.clear()
        self.source_map.clear()
def extract_citations_from_answer(answer: str) -> List[int]:
    """
    Collect the source numbers cited in an answer.

    Args:
        answer: Generated answer containing [Source N] markers

    Returns:
        The distinct source IDs referenced in the answer, in ascending order
    """
    import re

    # Gather each cited number once, regardless of how often it appears.
    seen_ids = set()
    for hit in re.finditer(r'\[Source (\d+)\]', answer):
        seen_ids.add(int(hit.group(1)))

    # Present the IDs in ascending numeric order.
    ordered_ids = list(seen_ids)
    ordered_ids.sort()
    return ordered_ids