"""
DocAI Space: Document Intelligence Demo - Task 2: PDF to HTML + Figures
"""

import logging

import gradio as gr
from PIL import Image

from src.ui_state import create_initial_state, parse_cache, page_cache, hash_bytes
from src.pdf_io import load_pdf_pages
from src.docling_parse import parse_document
from src.crops import extract_figures
from src.infer_vision_qa import answer_question
from src.infer_chart2csv import extract_csv

logger = logging.getLogger(__name__)

# Upper bound on the number of PDF pages rendered per upload.
_MAX_PAGES = 20


def _reset_figure_state(session_state: dict) -> None:
    """Clear everything in the session that refers to the current figure.

    Called at the start of every upload so that an empty or failed upload
    cannot leave the previous document's figures selectable.
    """
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    session_state["figures_info"] = []
    session_state["selected_figure"] = None


def process_upload(file_path: str, session_state: dict) -> tuple:
    """Parse an uploaded PDF and load its figures into the session.

    Args:
        file_path: Path of the uploaded PDF; falsy when nothing was uploaded.
        session_state: Per-session state dict; mutated in place.

    Returns:
        ``(status, html, figure_status, figure_caption, figure_image,
        session_state)``, matching the outputs of the "Load PDF" callback.
    """
    _reset_figure_state(session_state)

    if not file_path:
        return (
            "Please upload a PDF file.",
            "<p>No document loaded</p>",
            "No figures",
            "",
            None,
            session_state,
        )

    try:
        with open(file_path, "rb") as f:
            file_bytes = f.read()

        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes
        status_lines = ["PDF loaded successfully."]

        # Load pages (check shared cache first). The page limit is part of
        # the key, so renders made with a different limit are not reused.
        cache_key = f"{file_hash}_{_MAX_PAGES}"
        if cache_key in page_cache:
            page_images = page_cache[cache_key]
        else:
            page_images = load_pdf_pages(file_bytes, max_pages=_MAX_PAGES)
            page_cache[cache_key] = page_images
        session_state["page_images"] = page_images
        status_lines.append(
            f"Number of pages rendered: {len(page_images)} (max {_MAX_PAGES})."
        )

        # Parse with Docling (check shared cache first).
        if file_hash in parse_cache:
            parse_result = parse_cache[file_hash]
        else:
            parse_result = parse_document(file_bytes)
            parse_cache[file_hash] = parse_result
        session_state["parsed_result"] = parse_result
        status_lines.append("Document parsing done using Docling.")

        # Extract figures.
        figures_info = extract_figures(page_images, parse_result.get("figures", []))
        session_state["figures_info"] = figures_info
        status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

        # Select the first figure (index 0 was set by the reset above).
        session_state["selected_figure"] = figures_info[0] if figures_info else None
        fig_status, fig_caption, fig_image = _get_figure_display(session_state)

        html_content = parse_result.get("html", "<p>No HTML available</p>")
        status = "\n".join(status_lines)
        return status, html_content, fig_status, fig_caption, fig_image, session_state

    except Exception as e:
        # Top-level UI boundary: log the traceback and report the error in
        # the status box instead of letting the callback raise.
        logger.exception("Error while processing uploaded PDF")
        error_msg = f"Error: {str(e)}"
        return (
            error_msg,
            "<p>Error loading document</p>",
            "Error",
            "",
            None,
            session_state,
        )


def _get_figure_display(session_state: dict) -> tuple:
    """Return ``(status, caption, image)`` for the session's current figure.

    Returns ``("No figures found", "", None)`` when the session holds no
    figures.
    """
    figures_info = session_state.get("figures_info", [])
    if not figures_info:
        return "No figures found", "", None

    idx = session_state.get("current_figure_index", 0)
    # A stale index must not raise inside a UI callback; fall back to the
    # first figure instead.
    if not 0 <= idx < len(figures_info):
        idx = 0

    fig = figures_info[idx]
    fig_status = f"Figure {idx + 1} of {len(figures_info)} (Page {fig['page'] + 1})"
    fig_caption = fig.get('caption', 'No caption')
    return fig_status, fig_caption, fig['image']


def _step_figure(session_state: dict, step: int) -> tuple:
    """Move the figure selection by ``step`` positions, wrapping around.

    Changing figure also clears the Q&A conversation history and the current
    image path stored in the session.

    Returns:
        ``(status, caption, image, session_state)``.
    """
    figures_info = session_state.get("figures_info", [])
    if not figures_info:
        return "No figures found", "", None, session_state

    idx = (session_state.get("current_figure_index", 0) + step) % len(figures_info)
    session_state["current_figure_index"] = idx
    session_state["selected_figure"] = figures_info[idx]
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None

    fig_status, fig_caption, fig_image = _get_figure_display(session_state)
    return fig_status, fig_caption, fig_image, session_state


def next_figure(session_state: dict) -> tuple:
    """Show next figure."""
    return _step_figure(session_state, 1)


def prev_figure(session_state: dict) -> tuple:
    """Show previous figure."""
    return _step_figure(session_state, -1)


def ask_question_helper(question: str, session_state: dict) -> tuple:
    """Ask a question about the selected figure.

    Args:
        question: Text typed by the user; empty or whitespace-only input is
            rejected without calling the model.
        session_state: Per-session state dict; the conversation history and
            current image path are updated in place.

    Returns:
        ``(answer_text, session_state)``.
    """
    if not question or not question.strip():
        return "Enter a question", session_state

    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        return "No figure selected", session_state

    try:
        image = selected_fig['image']
        history = session_state.get("conversation_history", [])
        image_path = session_state.get("current_image_path")

        result, updated_history, updated_image_path = answer_question(
            image, question, history, image_path
        )

        session_state["conversation_history"] = updated_history
        session_state["current_image_path"] = updated_image_path
        return f"Q: {question}\n\nA: {result}", session_state
    except Exception as e:
        logger.exception("Error while answering a question about the figure")
        return f"Error: {str(e)}", session_state


def load_current_figure(session_state: dict) -> tuple:
    """Return ``(status, caption, image)`` for the session's current figure.

    Used to refresh a tab's figure panel when that tab is selected.
    """
    return _get_figure_display(session_state)


def extract_csv_helper(session_state: dict) -> tuple:
    """Extract CSV from the selected chart.

    Stores the raw CSV text under ``"last_csv"`` in the session.

    Returns:
        ``(csv_text_in_a_fenced_block, session_state)``, or an error/status
        message in place of the CSV text.
    """
    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        return "No figure selected", session_state

    try:
        image = selected_fig['image']
        csv_text = extract_csv(image)
        session_state["last_csv"] = csv_text
        return f"```csv\n{csv_text}\n```", session_state
    except Exception as e:
        logger.exception("Error while extracting CSV from the figure")
        return f"Error: {str(e)}", session_state


def _build_figure_panel(heading: str) -> tuple:
    """Create the figure info/caption/image widgets and Previous/Next buttons.

    Must be called inside an open Gradio layout context.

    Returns:
        ``(info_box, caption_box, image, prev_button, next_button)``.
    """
    gr.Markdown(heading)
    info_box = gr.Textbox(label="Figure Info", interactive=False)
    caption_box = gr.Textbox(label="Caption", interactive=False)
    image = gr.Image(label="Figure", type="pil")
    with gr.Row():
        prev_button = gr.Button("Previous", scale=1)
        next_button = gr.Button("Next", scale=1)
    return info_box, caption_box, image, prev_button, next_button


def _wire_figure_navigation(state, prev_button, next_button, displays: list) -> None:
    """Connect Previous/Next buttons to the figure callbacks for one panel."""
    outputs = [*displays, state]
    prev_button.click(prev_figure, inputs=[state], outputs=outputs)
    next_button.click(next_figure, inputs=[state], outputs=outputs)


# BUILD APP WITH SIDE-BY-SIDE LAYOUT
with gr.Blocks(title="DocAI") as app:
    gr.Markdown("# Granite Vision 4.0 Demo: Document Intelligence")

    session_state = gr.State(create_initial_state)

    with gr.Tabs():
        # TAB 1: UPLOAD & PARSE
        with gr.Tab("Parse & Extract") as parse_tab:
            with gr.Row():
                file_path = gr.File(label="Upload PDF", file_types=[".pdf"], scale=4)
                load_btn = gr.Button("Load PDF", variant="primary", scale=1)

            status = gr.Textbox(label="Status", interactive=False, lines=2)

            # Side-by-side: HTML on left, Figures on right
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Parsed Document using Docling")
                    html_view = gr.HTML(
                        value="<p>Upload a PDF to see parsed content</p>"
                    )

                with gr.Column(scale=1):
                    (
                        fig_info,
                        fig_caption,
                        fig_image,
                        prev_btn,
                        next_btn,
                    ) = _build_figure_panel("### Extracted Figures")

            # Wire callbacks
            load_btn.click(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[
                    status,
                    html_view,
                    fig_info,
                    fig_caption,
                    fig_image,
                    session_state,
                ],
            )
            _wire_figure_navigation(
                session_state, prev_btn, next_btn, [fig_info, fig_caption, fig_image]
            )

            # Refresh the panel when returning to this tab: the selection may
            # have been changed from another tab's Previous/Next buttons.
            parse_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image],
            )

        # TAB 2: IMAGE Q&A
        with gr.Tab("Image Q&A") as qa_tab:
            gr.Markdown("Ask questions about the selected figure")

            with gr.Row():
                with gr.Column(scale=1):
                    (
                        qa_fig_info,
                        qa_fig_caption,
                        qa_fig_image,
                        qa_prev_btn,
                        qa_next_btn,
                    ) = _build_figure_panel("### Figure")

                with gr.Column(scale=1):
                    gr.Markdown("### Q&A")
                    question = gr.Textbox(
                        label="Question",
                        lines=2,
                        placeholder="e.g., What is shown in this chart?",
                    )
                    ask_btn = gr.Button("Ask", variant="primary")
                    answer = gr.Textbox(label="Answer", lines=8, interactive=False)

            # Wire callbacks for figure navigation in Q&A tab
            _wire_figure_navigation(
                session_state,
                qa_prev_btn,
                qa_next_btn,
                [qa_fig_info, qa_fig_caption, qa_fig_image],
            )

            ask_btn.click(
                ask_question_helper,
                inputs=[question, session_state],
                outputs=[answer, session_state],
            )

            # Populate figure display when switching to this tab
            qa_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image],
            )

        # TAB 3: CSV EXTRACTION
        with gr.Tab("Chart2CSV") as csv_tab:
            gr.Markdown("Extract CSV data from the selected chart")

            with gr.Row():
                with gr.Column(scale=1):
                    (
                        csv_fig_info,
                        csv_fig_caption,
                        csv_fig_image,
                        csv_prev_btn,
                        csv_next_btn,
                    ) = _build_figure_panel("### Figure")

                with gr.Column(scale=1):
                    gr.Markdown("### CSV Extraction")
                    extract_btn = gr.Button("Extract CSV", variant="primary")
                    csv_out = gr.Textbox(label="CSV", lines=8, interactive=False)

            _wire_figure_navigation(
                session_state,
                csv_prev_btn,
                csv_next_btn,
                [csv_fig_info, csv_fig_caption, csv_fig_image],
            )

            extract_btn.click(
                extract_csv_helper,
                inputs=[session_state],
                outputs=[csv_out, session_state],
            )

            # Populate figure display when switching to this tab
            csv_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image],
            )


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)