File size: 4,049 Bytes
d8ba418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Utilities for tracking and formatting source citations."""
import hashlib
import re
from typing import List, Dict, Any

from langchain_core.documents import Document


class CitationTracker:
    """Track source documents and hand out stable citation IDs.

    Every distinct document receives a 1-indexed source ID in the order it
    was first added. Adding a document that maps to an already-known
    deduplication key returns the ID assigned the first time.
    """

    def __init__(self):
        # Tracked documents; the document with source ID N is sources[N - 1].
        self.sources: List[Document] = []
        # Deduplication key (see _create_doc_key) -> source ID.
        self.source_map: Dict[str, int] = {}

    def add_document(self, doc: Document) -> int:
        """Add a document and return its source ID.

        Args:
            doc: LangChain Document with metadata.

        Returns:
            Source ID (1-indexed). If an equivalent document was added
            before, the previously assigned ID is returned and nothing new
            is stored.
        """
        doc_key = self._create_doc_key(doc)

        existing_id = self.source_map.get(doc_key)
        if existing_id is not None:
            return existing_id

        # IDs are 1-indexed, so the new ID equals the list length after append.
        self.sources.append(doc)
        source_id = len(self.sources)
        self.source_map[doc_key] = source_id
        return source_id

    def _create_doc_key(self, doc: Document) -> str:
        """Create the key used to deduplicate documents.

        Documents carrying both ``filename`` and ``chunk_id`` metadata are
        keyed as ``"<filename>_<chunk_id>"``. When either field is missing,
        a digest of the page content is appended so that different
        documents without identifying metadata are not merged into a single
        source.

        Args:
            doc: Document whose ``metadata`` and ``page_content`` are read.

        Returns:
            Deduplication key for ``doc``.
        """
        metadata = doc.metadata
        filename = metadata.get('filename')
        chunk_id = metadata.get('chunk_id')

        if filename is not None and chunk_id is not None:
            return f"{filename}_{chunk_id}"

        # Metadata alone cannot tell documents apart here; fall back to the
        # content itself (shortened digest keeps the key compact).
        digest = hashlib.sha256(
            doc.page_content.encode("utf-8", errors="replace")
        ).hexdigest()[:16]
        name_part = 'unknown' if filename is None else filename
        chunk_part = 'unknown' if chunk_id is None else chunk_id
        return f"{name_part}_{chunk_part}_{digest}"

    def format_context_with_citations(self, documents: List[Document]) -> str:
        """Format documents into a context string with source markers.

        Each document is registered with the tracker (see add_document), so
        calling this method also updates ``sources`` and ``source_map``.

        Args:
            documents: List of LangChain Documents.

        Returns:
            The documents' contents, each prefixed with its ``[Source N]``
            marker and separated by blank lines. Empty string for an empty
            list.
        """
        context_parts = []

        for doc in documents:
            source_id = self.add_document(doc)
            context_parts.append(f"[Source {source_id}] {doc.page_content}")

        return "\n\n".join(context_parts)

    @staticmethod
    def _coerce_score(raw: Any) -> float:
        """Return ``raw`` as a float, or 0.0 when it is missing or not numeric."""
        try:
            return float(raw)
        except (TypeError, ValueError):
            # None (absent or explicitly null score) or an unparsable value.
            return 0.0

    def get_sources_list(self, preview_chars: int = 200) -> List[Dict[str, Any]]:
        """Get a formatted list of all tracked sources.

        Args:
            preview_chars: Maximum number of content characters included in
                each ``text_preview`` before it is truncated with ``...``.
                Negative values are treated as 0.

        Returns:
            One dictionary per tracked source, in source-ID order, with the
            keys ``source_id``, ``filename``, ``doc_type``, ``ticker``,
            ``similarity_score``, ``chunk_id`` and ``text_preview``.
        """
        preview_chars = max(preview_chars, 0)
        sources_list = []

        for idx, doc in enumerate(self.sources, start=1):
            metadata = doc.metadata
            content = doc.page_content

            text_preview = content[:preview_chars]
            if len(content) > preview_chars:
                text_preview += "..."

            # chunk_id is reported as a string whenever it is present.
            chunk_id = metadata.get('chunk_id')
            if chunk_id is not None:
                chunk_id = str(chunk_id)

            sources_list.append({
                "source_id": idx,
                "filename": metadata.get('filename', 'unknown'),
                "doc_type": metadata.get('doc_type', 'unknown'),
                "ticker": metadata.get('ticker'),
                "similarity_score": self._coerce_score(
                    metadata.get('similarity_score')
                ),
                "chunk_id": chunk_id,
                "text_preview": text_preview,
            })

        return sources_list

    def clear(self):
        """Clear all tracked sources and their deduplication keys."""
        self.sources.clear()
        self.source_map.clear()


def extract_citations_from_answer(answer: str) -> List[int]:
    """Extract the source IDs cited in an answer.

    Recognizes ``[Source N]`` markers and, in addition, common variations of
    that marker: different letter case (``[source 2]``), extra whitespace,
    and several IDs inside one pair of brackets (``[Source 1, Source 3]``,
    ``[Source 4, 5]``, ``[Sources 1 and 2]``).

    Args:
        answer: Generated answer with [Source N] citations. ``None`` or an
            empty string yields an empty list.

    Returns:
        Sorted list of unique source IDs mentioned in the answer.
    """
    if not answer:
        return []

    # One bracketed citation: "Source"/"Sources", then one or more IDs joined
    # by ",", ";", "&" or "and", each optionally re-prefixed with "Source".
    marker = re.compile(
        r'\[\s*sources?\s*'
        r'(\d+(?:\s*(?:,|;|&|and)\s*(?:sources?\s*)?\d+)*)'
        r'\s*\]',
        re.IGNORECASE,
    )

    cited_sources = set()
    for id_list in marker.findall(answer):
        cited_sources.update(int(num) for num in re.findall(r'\d+', id_list))

    return sorted(cited_sources)