| | """ |
| | DocAI Space: Document Intelligence Demo - Task 2: PDF to HTML + Figures |
| | """ |
| |
|
| | import gradio as gr |
| | from PIL import Image |
| |
|
| | from src.ui_state import create_initial_state, parse_cache, page_cache, hash_bytes |
| | from src.pdf_io import load_pdf_pages |
| | from src.docling_parse import parse_document |
| | from src.crops import extract_figures |
| | from src.infer_vision_qa import answer_question |
| | from src.infer_chart2csv import extract_csv |
| |
|
| |
|
def process_upload(file_path: str, session_state: dict) -> tuple:
    """Parse an uploaded PDF and load its pages and figures into the session.

    Args:
        file_path: Path of the uploaded PDF as handed over by the upload
            widget; a falsy value means nothing was uploaded.
        session_state: Per-session dict; it is mutated in place and also
            returned so the UI can store it back.

    Returns:
        A 6-tuple ``(status, html, figure_status, figure_caption,
        figure_image, session_state)`` matching the outputs wired to the
        "Load PDF" button.
    """
    max_pages = 20

    # Drop everything tied to the previously loaded document *before* doing
    # any work. Clearing the figures here (not only the index/conversation)
    # keeps the other tabs from operating on a stale figure when this load
    # ends with "no file" or an error.
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    session_state["figures_info"] = []
    session_state["selected_figure"] = None

    if not file_path:
        return (
            "Please upload a PDF file.",
            "<p>No document loaded</p>",
            "No figures",
            "",
            None,
            session_state,
        )

    try:
        with open(file_path, 'rb') as f:
            file_bytes = f.read()

        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes

        status_lines = ["PDF loaded successfully."]

        # Rendered pages are cached per (content hash, page limit) so that
        # re-uploading the same file skips rendering.
        # NOTE(review): page_cache / parse_cache come from src.ui_state; if
        # they are plain module-level dicts they grow without bound — confirm.
        cache_key = f"{file_hash}_{max_pages}"
        if cache_key in page_cache:
            page_images = page_cache[cache_key]
        else:
            page_images = load_pdf_pages(file_bytes, max_pages=max_pages)
            page_cache[cache_key] = page_images
        session_state["page_images"] = page_images
        status_lines.append(f"Number of pages rendered: {len(page_images)} (max {max_pages}).")

        # The parse result is cached on the content hash alone.
        if file_hash in parse_cache:
            parse_result = parse_cache[file_hash]
        else:
            parse_result = parse_document(file_bytes)
            parse_cache[file_hash] = parse_result
        session_state["parsed_result"] = parse_result
        status_lines.append("Document parsing done using Docling.")

        # Crop the figures reported by the parser out of the rendered pages.
        figures_info = extract_figures(page_images, parse_result.get("figures", []))
        session_state["figures_info"] = figures_info
        status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

        # Select the first figure (index was reset to 0 above) and reuse the
        # shared helper so the label format matches next/prev navigation.
        session_state["selected_figure"] = figures_info[0] if figures_info else None
        fig_status, fig_caption, fig_image = _get_figure_display(session_state)

        html_content = parse_result.get("html", "<p>No HTML available</p>")
        status = "\n".join(status_lines)

        return status, html_content, fig_status, fig_caption, fig_image, session_state

    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # crashing the handler, and keep the full traceback in the logs.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        # A failure part-way through must not leave half-loaded figures
        # selectable while the UI shows an error.
        session_state["figures_info"] = []
        session_state["selected_figure"] = None
        error_msg = f"Error: {str(e)}"
        return error_msg, "<p>Error loading document</p>", "Error", "", None, session_state
| |
|
| |
|
| | def _get_figure_display(session_state: dict) -> tuple: |
| | """Helper to get current figure info, caption, and image from session state.""" |
| | figures_info = session_state.get("figures_info", []) |
| | idx = session_state.get("current_figure_index", 0) |
| |
|
| | if not figures_info: |
| | return "No figures found", "", None |
| |
|
| | fig = figures_info[idx] |
| | fig_status = f"Figure {idx + 1} of {len(figures_info)} (Page {fig['page'] + 1})" |
| | fig_caption = fig.get('caption', 'No caption') |
| | return fig_status, fig_caption, fig['image'] |
| |
|
| |
|
def next_figure(session_state: dict) -> tuple:
    """Advance to the following figure, wrapping around after the last one.

    Switching figures also clears the Q&A conversation and cached image
    path stored in the session. Returns ``(status, caption, image,
    session_state)``.
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None, session_state

    current = session_state.get("current_figure_index", 0)
    new_index = (current + 1) % len(figures)
    session_state.update(
        current_figure_index=new_index,
        selected_figure=figures[new_index],
        conversation_history=[],
        current_image_path=None,
    )

    return (*_get_figure_display(session_state), session_state)
| |
|
| |
|
def prev_figure(session_state: dict) -> tuple:
    """Step back to the preceding figure, wrapping to the last from the first.

    Switching figures also clears the Q&A conversation and cached image
    path stored in the session. Returns ``(status, caption, image,
    session_state)``.
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None, session_state

    current = session_state.get("current_figure_index", 0)
    # Python's modulo keeps the result non-negative, so index 0 wraps to the end.
    new_index = (current - 1) % len(figures)
    session_state.update(
        current_figure_index=new_index,
        selected_figure=figures[new_index],
        conversation_history=[],
        current_image_path=None,
    )

    return (*_get_figure_display(session_state), session_state)
| |
|
| |
|
def ask_question_helper(question: str, session_state: dict) -> tuple:
    """Answer a question about the currently selected figure.

    Args:
        question: Free-text question from the UI. Empty or whitespace-only
            input is rejected without calling the model.
        session_state: Per-session dict providing ``selected_figure``,
            ``conversation_history`` and ``current_image_path``; the latter
            two are updated with the values returned by ``answer_question``.

    Returns:
        ``(text, session_state)`` where ``text`` is the formatted Q/A pair,
        a prompt to enter a question or select a figure, or an
        ``"Error: ..."`` message if answering failed.
    """
    # A question made only of spaces/newlines carries no content, so treat
    # it like an empty one instead of sending it to the model.
    if not question or not question.strip():
        return "Enter a question", session_state

    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        return "No figure selected", session_state

    try:
        image = selected_fig['image']
        history = session_state.get("conversation_history", [])
        image_path = session_state.get("current_image_path")

        result, updated_history, updated_image_path = answer_question(
            image, question, history, image_path
        )

        # Persist the conversation so follow-up questions continue it.
        session_state["conversation_history"] = updated_history
        session_state["current_image_path"] = updated_image_path

        return f"Q: {question}\n\nA: {result}", session_state
    except Exception as e:
        # UI boundary: surface the failure in the answer box.
        return f"Error: {str(e)}", session_state
| |
|
| |
|
def load_current_figure(session_state: dict) -> tuple:
    """Return the (status, caption, image) triple for the session's current figure.

    Used to refresh a tab's figure panel from session state without
    changing the state itself.
    """
    return _get_figure_display(session_state)
| |
|
| |
|
| | def extract_csv_helper(session_state: dict) -> tuple: |
| | """Extract CSV from selected chart.""" |
| | selected_fig = session_state.get("selected_figure") |
| | if selected_fig is None: |
| | return "No figure selected", session_state |
| |
|
| | try: |
| | image = selected_fig['image'] |
| | csv_text = extract_csv(image) |
| | session_state["last_csv"] = csv_text |
| | return f"```csv\n{csv_text}\n```", session_state |
| | except Exception as e: |
| | return f"Error: {str(e)}", session_state |
| |
|
| |
|
| | |
# Gradio UI: three tabs that share one per-session state dict. Each tab has
# its own figure panel (info, caption, image, Previous/Next) bound to the
# same navigation handlers, and refreshes that panel from session state
# whenever the tab is selected so all tabs show the same current figure.
with gr.Blocks(title="DocAI") as app:
    gr.Markdown("# Granite Vision 4.0 Demo: Document Intelligence")

    # A callable is passed so the initial state is built by
    # create_initial_state rather than shared as one literal value.
    session_state = gr.State(create_initial_state)

    with gr.Tabs():
        # Tab 1: upload a PDF, show the parsed HTML and the extracted figures.
        with gr.Tab("Parse & Extract") as parse_tab:
            with gr.Row():
                file_path = gr.File(label="Upload PDF", file_types=[".pdf"], scale=4)
                load_btn = gr.Button("Load PDF", variant="primary", scale=1)

            status = gr.Textbox(label="Status", interactive=False, lines=2)

            # Parsed document on the left, figure browser on the right.
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Parsed Document using Docling")
                    html_view = gr.HTML(value="<p>Upload a PDF to see parsed content</p>")

                with gr.Column(scale=1):
                    gr.Markdown("### Extracted Figures")
                    fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    fig_caption = gr.Textbox(label="Caption", interactive=False)
                    fig_image = gr.Image(label="Figure", type="pil")

                    with gr.Row():
                        prev_btn = gr.Button("Previous", scale=1)
                        next_btn = gr.Button("Next", scale=1)

            # Event wiring for tab 1.
            load_btn.click(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[status, html_view, fig_info, fig_caption, fig_image, session_state]
            )
            next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state]
            )
            prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state]
            )

            # Re-sync this tab's figure panel when the user comes back to it;
            # without this it keeps showing the figure that was current
            # before Previous/Next was used in another tab.
            parse_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image]
            )

        # Tab 2: ask free-text questions about the selected figure.
        with gr.Tab("Image Q&A") as qa_tab:
            gr.Markdown("Ask questions about the selected figure")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    qa_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    qa_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    qa_fig_image = gr.Image(label="Figure", type="pil")

                    with gr.Row():
                        qa_prev_btn = gr.Button("Previous", scale=1)
                        qa_next_btn = gr.Button("Next", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### Q&A")
                    question = gr.Textbox(label="Question", lines=2, placeholder="e.g., What is shown in this chart?")
                    ask_btn = gr.Button("Ask", variant="primary")
                    answer = gr.Textbox(label="Answer", lines=8, interactive=False)

            # Event wiring for tab 2.
            qa_prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, session_state]
            )
            qa_next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, session_state]
            )
            ask_btn.click(
                ask_question_helper,
                inputs=[question, session_state],
                outputs=[answer, session_state]
            )

            # Show the session's current figure when this tab is opened.
            qa_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image]
            )

        # Tab 3: extract tabular data from the selected chart.
        with gr.Tab("Chart2CSV") as csv_tab:
            gr.Markdown("Extract CSV data from the selected chart")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    csv_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    csv_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    csv_fig_image = gr.Image(label="Figure", type="pil")

                    with gr.Row():
                        csv_prev_btn = gr.Button("Previous", scale=1)
                        csv_next_btn = gr.Button("Next", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### CSV Extraction")
                    extract_btn = gr.Button("Extract CSV", variant="primary")
                    csv_out = gr.Textbox(label="CSV", lines=8, interactive=False)

            # Event wiring for tab 3.
            csv_prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, session_state]
            )
            csv_next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, session_state]
            )
            extract_btn.click(
                extract_csv_helper,
                inputs=[session_state],
                outputs=[csv_out, session_state]
            )
            # Show the session's current figure when this tab is opened.
            csv_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image]
            )
| |
|
| |
|
if __name__ == "__main__":
    # Start the server only when run as a script: listen on every network
    # interface at port 7860, with Gradio's public share link disabled.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
| |
|