Spaces:
Build error
Build error
| # Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
| # All rights reserved. | |
| # See the top-level LICENSE and NOTICE files for details. | |
| # LLNL-CODE-838964 | |
| # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
| import torch | |
| import gradio as gr | |
| from pathlib import Path | |
| from torchvision.transforms import ToPILImage, ToTensor | |
| tensor_to_image = ToPILImage() | |
| image_to_tensor = ToTensor() | |
| import sys | |
| sys.path.append('DiT_Extractor/') | |
| sys.path.append('CrossEncoder/') | |
| sys.path.append('UnifiedQA/') | |
| import dit_runner | |
| import sentence_extractor | |
| import cross_encoder | |
| import demo_QA | |
| from torchvision.transforms import ToPILImage | |
| tensor_to_image = ToPILImage() | |
def run_fn(pdf_file_obj, question_text, input_topk):
    """Run the detect -> extract -> rank -> answer pipeline on one PDF.

    Args:
        pdf_file_obj: Uploaded file object. Only its ``name`` attribute is
            read and handed to the DiT runner as the PDF to process.
        question_text: Question passed to the cross-encoder ranking step.
        input_topk: Number of top-ranked contexts given to the QA step;
            coerced with ``int()``.

    Returns:
        A 4-tuple ``(viz_images, contexts_json, ranked_contexts_json,
        history)`` where ``viz_images`` is whatever
        ``dit_runner.get_dit_preds`` returned, the two JSON values are file
        names derived from the PDF name, and ``history`` is a list of
        ``(scored context text, answer)`` pairs in reverse of the order
        returned by ``demo_QA.get_qa_results``.

    Raises:
        ValueError: If no PDF file object was supplied.
    """
    # Fail with an actionable message instead of an AttributeError on `.name`
    # when the handler is triggered before a file has been uploaded.
    if pdf_file_obj is None:
        raise ValueError('No PDF file was provided; upload a PDF before running QA.')

    # Coerce up front so a bad value is rejected before the slow model stages.
    input_topk = int(input_topk)

    pdf = pdf_file_obj.name
    print('Running PDF: {0}'.format(pdf))

    # Stage 1: layout analysis; predictions below the threshold are dropped.
    viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)

    # The intermediate files are addressed purely by naming convention:
    # "<pdf name minus its last 4 characters>.json", then "contexts_" and
    # "ranked_" prefixes.
    # NOTE(review): presumably each helper writes the file the next stage
    # reads -- confirm against the helper modules.
    entity_json = '{0}.json'.format(Path(pdf).name[:-4])

    # Stage 2: text/context extraction from the detected entities.
    sentence_extractor.get_contexts(entity_json)
    contexts_json = 'contexts_{0}'.format(entity_json)

    # Stage 3: rank the extracted contexts against the question.
    cross_encoder.get_ranked_contexts(contexts_json, question_text)
    ranked_contexts_json = 'ranked_{0}'.format(contexts_json)

    # Stage 4: question answering over the top-k ranked contexts.
    qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, input_topk)

    history = [
        ('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(score, context), answer)
        for context, score, answer in zip(
            qa_results['contexts'],
            qa_results['context_scores'],
            qa_results['answers'],
        )
    ]
    # Show in ascending order of score, since results box is already scrolled down.
    history = history[::-1]

    return viz_images, contexts_json, ranked_contexts_json, history
# Static copy shown in the UI, kept apart from the layout code below.
_INTRO_HTML = "<center>This is a supplemental demo for our recent paper, expected to be publically available around October: <b>Detect, Retrieve, Comprehend: A Flexible Framework for Zero-Shot Document-Level Question Answering</b>. In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation. Note that demo runtimes may be between 2-8 minutes, since this is currently cpu-based Space.</center>"

_SUMMARY_MARKDOWN = '''
- <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
- Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc
- <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using PDFMiner6 -> Tokenize & Sentence Split if tokenizer max length is exceeded
- <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
- <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
'''

_GALLERY_HELP_MARKDOWN = '''
- The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
- If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab"
'''

# Each example row is [pdf path, question, top-k], matching the input widgets.
examples = [
    ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
    ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
    ['examples/1810.04805.pdf', 'What is this paper about?', 5],
    ['examples/1810.04805.pdf', 'What is the model size?', 5],
    ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
    ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
]

# NOTE(review): the nesting of rows/columns below was reconstructed from a
# copy of this file that had lost its indentation -- confirm it matches the
# intended layout.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Detect-Retrieve-Comprehend for Document-Level QA</center></h1>")
    gr.Markdown(_INTRO_HTML)

    with gr.Row():
        # Left column: inputs, pipeline summary and canned examples.
        with gr.Column():
            with gr.Row():
                input_pdf_file = gr.File(file_count='single', label='PDF File')
            with gr.Row():
                input_question_text = gr.Textbox(label='Question')
            with gr.Row():
                input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
            with gr.Row():
                button_run = gr.Button('Run QA on Document')
            gr.Markdown("<h3><center>Summary</center></h3>")
            with gr.Row():
                gr.Markdown(_SUMMARY_MARKDOWN)
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[input_pdf_file, input_question_text, input_k_percent],
                )

        # Right column: every output produced by run_fn.
        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(label='DiT Predicted Entities')
            with gr.Row():
                gr.Markdown(_GALLERY_HELP_MARKDOWN)
            with gr.Row():
                output_contexts = gr.File(label='Detected Contexts', interactive=False)
                output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
            with gr.Row():
                output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()

    gr.Markdown("<h3><center>Related Work & Code</center></h3>")
    gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    # NOTE(review): the CrossEncoder links are identical to the DiT links on
    # the line above -- looks like a copy-paste; confirm the intended URLs.
    gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")

    # The order of inputs/outputs matches run_fn's parameters and return tuple.
    button_run.click(
        fn=run_fn,
        inputs=[input_pdf_file, input_question_text, input_k_percent],
        outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results],
    )

demo.launch(enable_queue=True)