Spaces:
Paused
Paused
| ##################################################### | |
| ### DOCUMENT PROCESSOR [CITATION] | |
| ##################################################### | |
| # Jonathan Wang | |
| # ABOUT: | |
| # This project creates an app to chat with PDFs. | |
| # This is the CITATION | |
| # which adds citation information to the LLM response | |
| ##################################################### | |
| ## TODO Board: | |
| # Investigate using LLM model weights with attention to determine citations. | |
| # https://gradientscience.org/contextcite/ | |
| # https://github.com/MadryLab/context-cite/blob/main/context_cite/context_citer.py#L25 | |
| # https://github.com/MadryLab/context-cite/blob/main/context_cite/context_partitioner.py | |
| # https://github.com/MadryLab/context-cite/blob/main/context_cite/solver.py | |
| ##################################################### | |
| ## IMPORTS | |
| from __future__ import annotations | |
| from collections import defaultdict | |
| from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING | |
| import warnings | |
| import numpy as np | |
| from llama_index.core.base.response.schema import RESPONSE_TYPE, Response | |
| if TYPE_CHECKING: | |
| from llama_index.core.schema import NodeWithScore | |
| # Own Modules | |
| from merger import _merge_on_scores | |
| from rapidfuzz import fuzz, process, utils | |
| # Lazy Loading: | |
| # from nltk import sent_tokenize # noqa: ERA001 | |
| ##################################################### | |
| ## CODE | |
class CitationBuilder:
    """Class that builds citations from responses.

    Fuzzy-matches the LLM response text against the source nodes it was
    generated from, inserts numbered citation tags (e.g. ``[0]``) into the
    response text, and stores the matching source excerpts in the response
    metadata under ``"citations"`` / ``"cited_nodes"``.
    """

    # Callable that splits a response string into sentences.
    text_splitter: Callable[[str], list[str]]

    def __init__(self, text_splitter: Callable[[str], list[str]] | None = None) -> None:
        """Initialize the builder.

        Args:
            text_splitter: Function splitting a text into sentences.
                Defaults to ``nltk.sent_tokenize`` (imported lazily to keep
                module import cheap).
        """
        if not text_splitter:
            from nltk import sent_tokenize
            text_splitter = sent_tokenize
        self.text_splitter = text_splitter

    @classmethod
    def class_name(cls) -> str:
        """Return the class name identifier.

        NOTE(fix): the original took ``cls`` but was missing ``@classmethod``,
        so ``CitationBuilder.class_name()`` raised; instance calls behave the
        same as before.
        """
        return "CitationBuilder"

    def convert_to_response(self, input_response: RESPONSE_TYPE) -> Response:
        """Convert any supported response type into a baseline ``Response``.

        Other response types are converted so the full generated response
        text is available.

        Raises:
            TypeError: If the underlying response is a coroutine
                (async responses are not handled yet).
        """
        if not isinstance(input_response, Response):
            response = input_response.get_response()
            if isinstance(response, Response):
                return response
            # TODO(Jonathan Wang): Handle async responses with Coroutines
            msg = "Expected Response object, got Coroutine"
            raise TypeError(msg)
        return input_response

    def find_nearest_whitespace(
        self,
        input_text: str,
        input_index: int,
        right_to_left: bool = False
    ) -> int:
        """Given a string and an index, find the index of the whitespace closest to the index.

        Args:
            input_text: The text to search in.
            input_index: The position to search from (must be inside the string).
            right_to_left: If True, scan leftwards from ``input_index``;
                otherwise scan rightwards (starting at ``input_index`` itself).

        Returns:
            Index of the nearest whitespace character. Falls back to ``0``
            when scanning left finds none, and ``len(input_text)`` when
            scanning right finds none.

        Raises:
            ValueError: If ``input_index`` is outside the string.
        """
        if (input_index < 0 or input_index >= len(input_text)):
            msg = "find_nearest_whitespace: index beyond string."
            raise ValueError(msg)

        if right_to_left:
            find_text = input_text[:input_index]
            for index, char in enumerate(reversed(find_text)):
                if char.isspace():
                    return (len(find_text) - 1 - index)
            return 0
        else:
            find_text = input_text[input_index:]
            for index, char in enumerate(find_text):
                if char.isspace():
                    return (input_index + index)
            return (len(input_text))

    def get_citations(
        self,
        input_response: RESPONSE_TYPE,
        citation_threshold: int = 70,
        citation_len: int = 128
    ) -> Response:
        """Add citation tags to the response text and citations to metadata.

        Args:
            input_response: The response to annotate.
            citation_threshold: Minimum fuzzy-match score (0-100) for a
                sentence to receive a citation.
            citation_len: Approximate character length of each cited excerpt.

        Returns:
            The response with ``[n]`` tags inserted into its text and
            ``cited_nodes`` / ``citations`` entries added to its metadata.
        """
        response = self.convert_to_response(input_response)
        if not response.response or not response.source_nodes:
            return response

        # Get current response text:
        response_text = response.response
        source_nodes = response.source_nodes

        # 0. Get candidate nodes for citation.
        # Fuzzy match each source node text against the response text.
        source_texts: dict[str, list[NodeWithScore]] = defaultdict(list)
        for node in source_nodes:
            if (
                (len(getattr(node.node, "text", "")) > 0) and
                (len(node.node.metadata) > 0)
            ):  # filter out non-text nodes and intermediate nodes from SubQueryQuestionEngine
                source_texts[node.node.text].append(node)  # type: ignore

        fuzzy_matches = process.extract(
            response_text,
            list(source_texts.keys()),
            scorer=fuzz.partial_ratio,
            processor=utils.default_process,
            score_cutoff=max(10, citation_threshold - 10)
        )
        # Convert extracted matches of form (Match, Score, Rank) into scores for all source_texts.
        if fuzzy_matches:
            fuzzy_texts, _, _ = zip(*fuzzy_matches)
            fuzzy_nodes = [source_texts[text][0] for text in fuzzy_texts]
        else:
            return response

        # 1. Combine fuzzy score and source text semantic/reranker score.
        # NOTE: for our merge here, we value the nodes with strong fuzzy text matching over other node types.
        cited_nodes = _merge_on_scores(
            a_list=fuzzy_nodes,
            b_list=source_nodes,  # same nodes, different scores (fuzzy vs semantic/bm25/reranker)
            a_scores_input=[getattr(node, "score", np.nan) for node in fuzzy_nodes],
            b_scores_input=[getattr(node, "score", np.nan) for node in source_nodes],
            a_weight=0.85,  # we want to heavily prioritize the fuzzy text for matches
            top_k=3  # maximum of three source options.
        )

        # 2. Add cited nodes text to the response text, and cited nodes as metadata.
        # For each sentence in the response, if there is a match in the source text, add a citation tag.
        response_sentences = self.text_splitter(response_text)
        output_text = ""
        output_citations = ""
        citation_tag = 0

        for response_sentence in response_sentences:
            # Get fuzzy citation at sentence level
            best_alignment = None
            best_score = 0
            best_node = None

            for source_node in source_nodes:
                source_node_text = getattr(source_node.node, "text", "")
                new_alignment = fuzz.partial_ratio_alignment(
                    response_sentence,
                    source_node_text,
                    processor=utils.default_process,
                    score_cutoff=citation_threshold
                )
                new_score = 0.0
                if (new_alignment is not None and (new_alignment.src_end - new_alignment.src_start) > 0):
                    # Re-score on the aligned substrings, weighted by how much
                    # of the response sentence the alignment covers.
                    new_score = fuzz.ratio(
                        source_node_text[new_alignment.src_start:new_alignment.src_end],
                        response_sentence[new_alignment.dest_start:new_alignment.dest_end],
                        processor=utils.default_process
                    )
                    new_score = new_score * (new_alignment.src_end - new_alignment.src_start) / float(len(response_sentence))
                if (new_score > best_score):
                    best_alignment = new_alignment
                    best_score = new_score
                    best_node = source_node

            if (best_score <= 0 or best_node is None or best_alignment is None):
                # No match
                output_text += response_sentence
                continue

            # Add citation tag to text, snapped back to the nearest whitespace
            # so the tag lands between words.
            citation_tag_position = self.find_nearest_whitespace(response_sentence, best_alignment.dest_start, right_to_left=True)
            output_text += response_sentence[:citation_tag_position]  # response up to the quote
            output_text += f" [{citation_tag}] "  # add citation tag
            output_text += response_sentence[citation_tag_position:]  # response after the quote

            # Add citation text to citations: expand the matched span out to
            # roughly `citation_len` characters, snapping all cut points to
            # whitespace so no word is split.
            citation = getattr(best_node.node, "text", "")
            citation_margin = round((citation_len - (best_alignment.src_end - best_alignment.src_start)) / 2)
            nearest_whitespace_pre = self.find_nearest_whitespace(citation, max(0, best_alignment.src_start), right_to_left=True)
            nearest_whitespace_post = self.find_nearest_whitespace(citation, min(len(citation) - 1, best_alignment.src_end), right_to_left=False)
            nearest_whitespace_prewindow = self.find_nearest_whitespace(citation, max(0, nearest_whitespace_pre - citation_margin), right_to_left=True)
            nearest_whitespace_postwindow = self.find_nearest_whitespace(citation, min(len(citation) - 1, nearest_whitespace_post + citation_margin), right_to_left=False)
            citation_text = (
                citation[nearest_whitespace_prewindow + 1: nearest_whitespace_pre + 1]
                + "|||||"
                + citation[nearest_whitespace_pre + 1:nearest_whitespace_post]
                + "|||||"
                + citation[nearest_whitespace_post:nearest_whitespace_postwindow]
                # NOTE(fix): original had mojibake "β¦" here; intended ellipsis.
                + f"… <<{best_node.node.metadata.get('name', '')}, Page(s) {best_node.node.metadata.get('page_number', '')}>>"
            )
            output_citations += f"[{citation_tag}]: {citation_text}\n\n"
            citation_tag += 1

        # Create output
        if response.metadata is not None:
            # NOTE: metadata is certainly existent by now, but the schema allows None...
            response.metadata["cited_nodes"] = cited_nodes
            response.metadata["citations"] = output_citations
        response.response = output_text  # update response to include citation tags
        return response

    def add_citations_to_response(self, input_response: Response) -> Response:
        """Append the citations section to the response text.

        Args:
            input_response: A response, ideally already processed by
                ``get_citations``; if not, citations are generated first.

        Returns:
            The response with a "----- CITATIONS -----" section appended
            when citations exist.

        Raises:
            ValueError: If the response has no ``metadata`` attribute.
        """
        if not hasattr(input_response, "metadata"):
            msg = "Input response does not have metadata."
            raise ValueError(msg)
        elif input_response.metadata is None or "citations" not in input_response.metadata:
            warnings.warn("Input response does not have citations.", stacklevel=2)
            input_response = self.get_citations(input_response)

        # Add citation text to response.
        # NOTE(fix): guard on metadata itself — get_citations can leave it
        # None, and `None.get(...)` raised AttributeError in the original.
        if (input_response.metadata is not None and input_response.metadata.get("citations", "") != ""):
            input_response.response = (
                input_response.response
                + "\n\n----- CITATIONS -----\n\n"
                + input_response.metadata.get('citations', "")
            )  # type: ignore
        return input_response

    def __call__(self, input_response: RESPONSE_TYPE, *args: Any, **kwds: Any) -> Response:
        """Alias for ``get_citations``."""
        return self.get_citations(input_response, *args, **kwds)
def get_citation_builder() -> CitationBuilder:
    """Factory: return a CitationBuilder using the default sentence splitter."""
    builder = CitationBuilder()
    return builder