File size: 6,120 Bytes
d60bab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import ast
import re
from collections.abc import Mapping
from typing import List, Dict, Any, Union

from langchain_core.messages import SystemMessage, HumanMessage

from .prompts import system_prompt

# ---------------------------------------------------------------------
# Core Processing Functions
# ---------------------------------------------------------------------
def _parse_citations(response: str) -> List[int]:
    """Parse citation numbers from response text"""
    citation_pattern = r'\[(\d+)\]'
    matches = re.findall(citation_pattern, response)
    citation_numbers = sorted(list(set(int(match) for match in matches)))
    
    return citation_numbers

def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
    """Extract sources that were cited in the response"""
    if not cited_numbers:
        return []
    
    cited_sources = []
    for citation_num in cited_numbers:
        source_index = citation_num - 1
        
        if 0 <= source_index < len(processed_results):
            source = processed_results[source_index].copy()  # Make copy to avoid modifying original
            source['_citation_number'] = citation_num  # Preserve original citation number
            cited_sources.append(source)
    
    return cited_sources

def clean_citations(response: str) -> str:
    """Normalize citation markers to ``[N]`` and drop reference sections.

    Steps, in order:

    1. Remove a "References" / "Sources" / "Bibliography" section (either a
       markdown heading or a ``Label:`` line that starts a new line) together
       with everything after it.
    2. Rewrite citation variants to ``[N]``:
       ``(Document N[, Page P][, [Year ]Y])``, ``[Document N ...]``,
       section-style ``[N.M.K]``, and bare ``Document N[, Page P][, Y]``.
    3. Collapse doubled brackets ``[[N]]`` to ``[N]``.
    4. Tidy whitespace while keeping line breaks, so markdown paragraphs,
       lists and indentation survive.

    Args:
        response: Raw model output.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    # 1. Drop reference sections. DOTALL lets ".*$" run to the end of the
    #    string, so everything after the matched header is removed.
    ref_patterns = [
        r'\n\s*#+\s*References?\s*:?.*$',
        r'\n\s*#+\s*Sources?\s*:?.*$',
        r'\n\s*#+\s*Bibliography\s*:?.*$',
        r'\n\s*References?\s*:.*$',
        r'\n\s*Sources?\s*:.*$',
        r'\n\s*Bibliography\s*:.*$',
    ]
    for pattern in ref_patterns:
        response = re.sub(pattern, '', response, flags=re.IGNORECASE | re.DOTALL)

    # 2. Rewrite citation variants to [N]. Order matters: the bracketed and
    #    parenthesised forms must be consumed before the bare "Document N" rule.
    #    Separate rules for "[Document X: file, ...]" and "(Document X)" are not
    #    needed: the first two patterns already match every such string. A
    #    "Document X states/says ..." rule is also unnecessary because the bare
    #    rule already turns it into "[X] states ...".
    citation_rewrites = [
        # (Document X), (Document X, Page Y), (Document X, Page Y, Year Z)
        (r'\(Document\s+(\d+)(?:,\s*Page\s+\d+)?(?:,\s*(?:Year\s+)?\d+)?\)', re.IGNORECASE),
        # [Document X], [Document X, Page Y], [Document X: filename, ...]
        (r'\[Document\s+(\d+)(?:[^\]]*)\]', re.IGNORECASE),
        # [X.Y.Z] section-style numbers -> [X]
        (r'\[(\d+)\.[\d\.]+\]', 0),
        # Bare "Document X[, Page Y][, Year Z]" followed by whitespace, "," or "."
        (r'Document\s+(\d+)(?:,\s*Page\s+\d+)?(?:,\s*(?:Year\s+)?\d+)?(?=\s|[,.])', re.IGNORECASE),
    ]
    for pattern, flags in citation_rewrites:
        response = re.sub(pattern, r'[\1]', response, flags=flags)

    # 3. Clean up double citations left by the rewrites: [[1]] -> [1]
    response = re.sub(r'\[\[(\d+)\]\]', r'[\1]', response)

    # 4. Whitespace. The previous r'\s+' -> ' ' also swallowed newlines and
    #    flattened markdown into a single line. Instead:
    #    - squeeze inline runs (those following a non-space char) to one space,
    #      leaving leading indentation of nested list items intact;
    response = re.sub(r'(?<=\S)[^\S\n]+', ' ', response)
    #    - drop trailing whitespace on every line (also empties blank-ish lines);
    response = re.sub(r'[^\S\n]+$', '', response, flags=re.MULTILINE)
    #    - allow at most one blank line between paragraphs.
    response = re.sub(r'\n{3,}', '\n\n', response)

    return response.strip()

def _process_context(context: Union[str, List[Dict[str, Any]]]) -> tuple[str, List[Dict[str, Any]]]:
    """Process context and return formatted context string and processed results"""
    processed_results = []
    
    if isinstance(context, list):
        if not context:
            raise ValueError("No retrieval results provided")
        
        # Extract relevant fields from retrieval results
        for result in context:
            if isinstance(result, str):
                result = ast.literal_eval(result)
            
            metadata = result.get('answer_metadata', {})
            doc_info = {
                'answer': result.get('answer', ''),
                'filename': metadata.get('filename', 'Unknown'),
                'page': metadata.get('page', 'Unknown'),
                'year': metadata.get('year', 'Unknown'),
                'source': metadata.get('source', 'Unknown'),
                'document_id': metadata.get('_id', 'Unknown')
            }
            processed_results.append(doc_info)
        
        # Format context string - SIMPLIFIED TO ONLY USE [1], [2], [3]
        context_parts = []
        for i, result in enumerate(processed_results, 1):
            # Simple format: [1], [2], etc.
            context_parts.append(f"[{i}]\n{result['answer']}\n")
        
        formatted_context = "\n".join(context_parts)
        
    elif isinstance(context, str):
        if not context.strip():
            raise ValueError("Context cannot be empty")
        formatted_context = context
    else:
        raise ValueError("Context must be either a string or list of retrieval results")
    
    return formatted_context, processed_results

def _build_messages(system_prompt: str, question: str, context: str) -> list:
    """Assemble the two-message chat payload for the LLM.

    The system message carries ``system_prompt`` verbatim. The human message
    places the retrieved ``context`` under a ``### CONTEXT`` heading, followed
    by the ``question`` under ``### USER QUESTION``.
    """
    prompt_body = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_prompt),
        HumanMessage(content=prompt_body),
    ]

def _create_sources_list(cited_sources: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Create sources list for ChatUI format"""
    sources = []
    for result in cited_sources:
        filename = result.get('filename', 'Unknown')
        page = result.get('page', 'Unknown')
        year = result.get('year', 'Unknown')
        
        link = f"doc://{filename}"
        title_parts = [filename]
        if page != 'Unknown':
            title_parts.append(f"Page {page}")
        if year != 'Unknown':
            title_parts.append(f"({year})")
        
        sources.append({"link": link, "title": " - ".join(title_parts)})
    
    return sources