Spaces:
Running
Running
| """ app.py | |
| Question / answer over the Canadian Income Tax Act. | |
| https://laws-lois.justice.gc.ca/eng/acts/i-3.3/ | |
| Retrieval model: | |
| - LanceDB: support for hybrid search with reranking of results. | |
| - Full text search (lexical): BM25 | |
| - Vector search (semantic dense vectors): BAAI/bge-m3 | |
| Rerankers: | |
| - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI | |
| Generation: | |
| - Mistral | |
| :author: Didier Guillevic | |
| :date: 2025-04-27 | |
| """ | |
import logging
from collections.abc import Iterator

import gradio as gr
import lancedb

import llm_utils
# Logging: configure the root handler, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#
# LanceDB with the indexed documents
#

# Open the local database and the table holding the indexed act.
lance_db = lancedb.connect("lance.db")
lance_tbl = lance_db.open_table("income_tax_act")

#
# Retrieval: query types and reranker types
#

# Name shown to the user -> `query_type` value handed to `table.search`.
query_types = dict(lexical='fts', semantic='vector', hybrid='hybrid')

# One shared instance per reranker; the text-based ones score the 'text' column.
colbert_reranker = lancedb.rerankers.ColbertReranker(column='text')
answerai_reranker = lancedb.rerankers.AnswerdotaiRerankers(column='text')
crossencoder_reranker = lancedb.rerankers.CrossEncoderReranker(column='text')
# Reciprocal rank fusion is for hybrid search only.
reciprocal_rank_fusion_reranker = lancedb.rerankers.RRFReranker()

# Name shown to the user -> reranker instance (insertion order is kept).
reranker_types = dict([
    ('ColBERT', colbert_reranker),
    ('cross encoder', crossencoder_reranker),
    ('AnswerAI', answerai_reranker),
    ('Reciprocal Rank Fusion', reciprocal_rank_fusion_reranker),
])
def search_table(
    table: lancedb.table,
    query: str,
    query_type: str,
    reranker_name: str,
    top_k: int=5,
    overfetch_factor: int=2
):
    """Search a LanceDB table and return the best reranked rows.

    Args:
        table: the LanceDB table to search.
        query: the user's question.
        query_type: one of "fts", "vector" or "hybrid".
        reranker_name: a key of the module-level ``reranker_types`` mapping.
        top_k: number of rows to return.
        overfetch_factor: the search retrieves ``top_k * overfetch_factor``
            candidates so the reranker has more rows to choose from.

    Returns:
        A list of at most ``top_k`` row dicts (as produced by ``to_list()``,
        which exposes '_relevance_score').

    Raises:
        ValueError: if ``reranker_name`` or ``query_type`` is not recognised.
    """
    # Get the instance of reranker
    reranker = reranker_types.get(reranker_name)
    if reranker is None:
        logger.error(f"Invalid reranker name: {reranker_name}")
        raise ValueError(f"Invalid reranker selected: {reranker_name}")

    # Fail loudly on an unknown query type; previously this fell through both
    # branches and crashed with an UnboundLocalError on `results`.
    if query_type not in ("vector", "fts", "hybrid"):
        logger.error(f"Invalid query type: {query_type}")
        raise ValueError(f"Invalid query type selected: {query_type}")

    # Reciprocal rank fusion is for 'hybrid' search only: fall back to the
    # cross encoder for single-mode searches.
    if query_type != "hybrid" and reranker is reciprocal_rank_fusion_reranker:
        reranker = crossencoder_reranker

    # UI sliders may deliver the value as a float; slicing needs an int.
    top_k = int(top_k)

    # Only ONE limit is set on the builder. A second `.limit(top_k)` after
    # `.rerank()` would replace the over-fetch limit before the query runs,
    # so the reranker would only ever see `top_k` candidates. Truncation to
    # `top_k` is done on the reranked list below.
    results = (
        table.search(query, query_type=query_type)
        .limit(top_k * overfetch_factor)
        .rerank(reranker=reranker)
        .to_list()  # to get access to '_relevance_score'
    )
    return results[:top_k]
| # | |
| # Generation: query + context --> response | |
| # | |
def create_bulleted_list(texts: list[str], scores: list[float]=None) -> str:
    """Render the given strings as an HTML unordered list.

    Args:
        texts: one string per bullet.
        scores: optional scores paired with ``texts`` in order; when given,
            each bullet is prefixed with the score formatted to two decimals.
            Pairing stops at the shorter of the two sequences.

    Returns:
        The ``<ul>`` markup as a single string.
    """
    if scores is None:
        bullets = (f"<li>{entry}</li>" for entry in texts)
    else:
        bullets = (
            f"<li>(Score={value:.2f})\t{entry}</li>"
            for entry, value in zip(texts, scores)
        )
    return f"<ul>{''.join(bullets)}</ul>"
def generate_response(
    query: str,
    query_type: str,
    reranker_name: str,
    top_k: int
) -> Iterator[tuple[str, str, str]]:
    """Stream an answer to `query` built from snippets retrieved from LanceDB.

    This is a generator: each item carries the response accumulated so far,
    so a streaming UI can refresh its output as chunks arrive.

    Args:
        query: the user's question.
        query_type: search mode passed to `search_table`.
        reranker_name: reranker passed to `search_table`.
        top_k: number of snippets to retrieve and hand to the model.

    Yields:
        Tuples of three strings:
        - the response generated so far from the retrieved snippets
        - (html string): the references (origin of the snippets of text used
          to generate the answer)
        - (html string): the snippets of text used to generate the answer

    Raises:
        ValueError: propagated from `search_table` for an unknown reranker
            or query type.
    """
    # Nothing to search for: skip retrieval and the LLM call entirely.
    if not query or not query.strip():
        yield "Please enter a question.", "", ""
        return

    # Get results from LanceDB
    results = search_table(
        lance_tbl,
        query=query,
        query_type=query_type,
        reranker_name=reranker_name,
        top_k=top_k
    )

    # One reference line per retrieved row, built from its heading fields.
    references = [
        (
            f"{result['hlabel1']} {result['htitletext1']} "
            f"{result['division']} {result['subdivision']} "
            f"{result['section']}"
        ) for result in results
    ]
    references_html = "<h4>References</h4>\n" + create_bulleted_list(references)

    snippets = [result['text'] for result in results]
    scores = [result['_relevance_score'] for result in results]
    snippets_html = "<h4>Snippets</h4>\n" + create_bulleted_list(snippets, scores)

    # Generate the response from the LLM
    stream_response = llm_utils.generate_chat_response_streaming(
        query, '\n\n'.join(snippets)
    )

    model_response = ""
    for chunk in stream_response:
        # A chunk's delta may carry no text (content of None); treat it as
        # empty instead of failing on `str + None`.
        model_response += chunk.data.choices[0].delta.content or ""
        yield model_response, references_html, snippets_html
| # | |
| # User interface | |
| # | |
# Gradio UI: a question box, a streamed response, and a collapsible panel
# with references/snippets, sample questions, search parameters and docs.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Income Tax Act
    Question / answer over the Canadian [Income Tax Act](https://laws-lois.justice.gc.ca/eng/acts/i-3.3/)
    """)

    # Inputs: question
    question = gr.Textbox(
        label="Question to answer",
        placeholder=""
    )

    # Response
    response = gr.Markdown(
        label="Response"
    )

    # Button
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        # Hidden until a request is running (swapped with Submit below).
        cancel_button = gr.Button("Cancel", variant='stop', visible=False)
        clear_button = gr.Button("Clear", variant='secondary')

    # Additional inputs (created here with render=False, rendered further
    # down inside the "Search parameters" accordion).
    query_type = gr.Dropdown(
        # Pass a real list of (label, value) pairs rather than a dict view.
        choices=list(query_types.items()),
        value='hybrid',
        label='Query type',
        render=False
    )
    reranker_name = gr.Dropdown(
        choices=list(reranker_types.keys()),
        value='cross encoder',
        label='Reranker',
        render=False
    )
    top_k = gr.Slider(
        minimum=2, maximum=5, value=5, step=1,
        label='Top k results', render=False
    )

    # Snippets, sample questions, search parameters, documentation
    with gr.Accordion("Snippets / sample questions / search parameters / documentation", open=False):

        # References and snippets
        with gr.Accordion("References & snippets", open=False):
            references = gr.HTML(label="References")
            snippets = gr.HTML(label="Snippets")

        # Example questions given default provided PDF file
        with gr.Accordion("Sample questions", open=False):
            gr.Examples(
                [
                    ["What is considered 'tax avoidance' under the General Anti-Avoidance Provision (GAAR)?",],
                    ['What is "section 160 avoidance planning"?',],
                    ['What are "tax shelters" and "reportable transactions", and what penalties can apply?',],
                    ["Are there penalties for making false statements on receipts for tax deductions or credits? ",],
                    ["Under what conditions is a lease agreement deemed not to be a lease for tax purposes?",],
                    ["What are the conditions under which a loan or debt may be considered to have a tax avoidance motive regarding the income of a particular individual?",],
                    ["When can the Minister reassess a tax return related to property transfers?",],
                ],
                inputs=[question, query_type, reranker_name, top_k],
                outputs=[response, references, snippets],
                fn=generate_response,
                cache_examples=False,
                label="Sample questions"
            )

        # Additional inputs: search parameters
        with gr.Accordion("Search parameters", open=False):
            with gr.Row():
                query_type.render()
                reranker_name.render()
                top_k.render()

        # Documentation
        with gr.Accordion("Documentation", open=False):
            gr.Markdown("""
            - Retrieval model
                - LanceDB: support for hybrid search search with reranking of results.
                - Full text search (lexical): BM25
                - Vector search (semantic dense vectors): BAAI/bge-m3
            - Rerankers
                - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
            - Generation
                - Mistral
            - Examples
                - Generated using Google NotebookLM
            """)

    # Click actions
    # Step 1: swap Submit for Cancel while the request is running.
    # NOTE(review): api_name="query" is attached to this button-toggle step,
    # not to the generation step; left unchanged to keep the exposed API
    # stable - confirm whether the endpoint should wrap generate_response.
    submit_event = response_button.click(
        fn=lambda: (gr.Button(visible=False), gr.Button(visible=True)),
        inputs=[],
        outputs=[response_button, cancel_button],
        api_name="query"
    )
    # Step 2: the streaming generation. Keep a handle on THIS step: it is
    # the one the Cancel button has to stop. Binding the handle to the end
    # of the whole chain would only reference the button-restore step.
    click_event = submit_event.then(
        fn=generate_response,
        inputs=[question, query_type, reranker_name, top_k],
        outputs=[response, references, snippets],
    )
    # Step 3: restore the buttons once generation has finished.
    click_event.then(
        fn=lambda: (gr.Button(visible=True), gr.Button(visible=False)),
        inputs=[],
        outputs=[response_button, cancel_button],
    )

    # Cancel: stop the running generation and restore the buttons.
    cancel_button.click(
        fn=lambda: (gr.Button(visible=True), gr.Button(visible=False)),
        inputs=[],
        outputs=[response_button, cancel_button],
        cancels=[click_event],
    )

    # Clear: reset the question and all three outputs.
    clear_button.click(
        fn=lambda: ('', '', '', ''),
        inputs=[],
        outputs=[question, response, references, snippets]
    )

demo.queue().launch(show_api=True)