# DocAI / app.py
# Pengyuan Li
# Rename demo title to Granite Vision 4.0 Demo: Document Intelligence
# a0b7bd5
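"""
DocAI Space: Document Intelligence Demo - Task 2: PDF to HTML + Figures

Gradio app that parses an uploaded PDF with Docling, shows the parsed HTML and
the extracted figures side by side, and provides tabs for asking questions
about a figure and for extracting CSV data from a chart.
"""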
import gradio as gr
from PIL import Image
from src.ui_state import create_initial_state, parse_cache, page_cache, hash_bytes
from src.pdf_io import load_pdf_pages
from src.docling_parse import parse_document
from src.crops import extract_figures
from src.infer_vision_qa import answer_question
from src.infer_chart2csv import extract_csv
def process_upload(file_path: str, session_state: dict) -> tuple:
    """Parse an uploaded PDF and load its figures into the session.

    The file is read from disk and hashed with ``hash_bytes``. Pages are
    rendered with ``load_pdf_pages`` (at most 20 pages) and the document is
    parsed with ``parse_document``; both results are memoised in the shared
    ``page_cache`` / ``parse_cache`` keyed by the file hash. Figures are then
    cropped with ``extract_figures`` and the first one is selected.

    Args:
        file_path: Path of the uploaded PDF on disk; empty or ``None`` when
            nothing has been uploaded.
        session_state: Per-session dict. It is mutated in place and also
            returned as the last element of the result.

    Returns:
        A 6-tuple ``(status, html, figure_status, figure_caption,
        figure_image, session_state)``.

    Exceptions raised while reading, rendering or parsing are not propagated:
    they are printed with a traceback and reported through the status string.
    """
    import traceback

    max_pages = 20

    # Reset all per-document selection state up front. Clearing the figure
    # list and the selected figure here ensures that an empty or failed upload
    # cannot leave the previous document's figures available to the Q&A and
    # CSV tabs while the UI reports that nothing is loaded.
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    session_state["figures_info"] = []
    session_state["selected_figure"] = None

    if not file_path:
        return (
            "Please upload a PDF file.",
            "<p>No document loaded</p>",
            "No figures",
            "",
            None,
            session_state,
        )

    try:
        with open(file_path, "rb") as f:
            file_bytes = f.read()
        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes
        status_lines = ["PDF loaded successfully."]

        # Load pages (check shared cache first). The page limit is part of
        # the key so a different limit never reuses a truncated render.
        cache_key = f"{file_hash}_{max_pages}"
        if cache_key in page_cache:
            page_images = page_cache[cache_key]
        else:
            page_images = load_pdf_pages(file_bytes, max_pages=max_pages)
            page_cache[cache_key] = page_images
        session_state["page_images"] = page_images
        status_lines.append(
            f"Number of pages rendered: {len(page_images)} (max {max_pages})."
        )

        # Parse with Docling (check shared cache first)
        if file_hash in parse_cache:
            parse_result = parse_cache[file_hash]
        else:
            parse_result = parse_document(file_bytes)
            parse_cache[file_hash] = parse_result
        session_state["parsed_result"] = parse_result
        status_lines.append("Document parsing done using Docling.")

        # Extract figures
        figures_info = extract_figures(page_images, parse_result.get("figures", []))
        session_state["figures_info"] = figures_info
        status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

        # Select first figure
        if figures_info:
            first_figure = figures_info[0]
            session_state["selected_figure"] = first_figure
            # Page numbers are stored zero-based and shown one-based.
            fig_status = (
                f"Figure 1 of {len(figures_info)} (Page {first_figure['page'] + 1})"
            )
            fig_caption = first_figure.get('caption', 'No caption')
            fig_image = first_figure['image']
        else:
            session_state["selected_figure"] = None
            fig_status = "No figures found"
            fig_caption = ""
            fig_image = None

        # Get HTML
        html_content = parse_result.get("html", "<p>No HTML available</p>")
        status = "\n".join(status_lines)
        return status, html_content, fig_status, fig_caption, fig_image, session_state
    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        # A failure after figures were stored must not leave them selectable
        # behind an "Error" display.
        session_state["figures_info"] = []
        session_state["selected_figure"] = None
        error_msg = f"Error: {str(e)}"
        return error_msg, "<p>Error loading document</p>", "Error", "", None, session_state
def _get_figure_display(session_state: dict) -> tuple:
"""Helper to get current figure info, caption, and image from session state."""
figures_info = session_state.get("figures_info", [])
idx = session_state.get("current_figure_index", 0)
if not figures_info:
return "No figures found", "", None
fig = figures_info[idx]
fig_status = f"Figure {idx + 1} of {len(figures_info)} (Page {fig['page'] + 1})"
fig_caption = fig.get('caption', 'No caption')
return fig_status, fig_caption, fig['image']
def next_figure(session_state: dict) -> tuple:
    """Advance the selection to the following figure, wrapping at the end.

    Updates the current index and selected figure in ``session_state`` and
    clears the conversation history and cached image path, since those belong
    to the previously selected figure.

    Returns ``(status_text, caption, image, session_state)``; when the session
    holds no figures the state is returned untouched with a placeholder
    display.
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None, session_state

    # Modulo makes stepping past the last figure land on the first one.
    new_index = (session_state.get("current_figure_index", 0) + 1) % len(figures)
    session_state.update(
        current_figure_index=new_index,
        selected_figure=figures[new_index],
        conversation_history=[],
        current_image_path=None,
    )

    status_text, caption_text, figure_image = _get_figure_display(session_state)
    return status_text, caption_text, figure_image, session_state
def prev_figure(session_state: dict) -> tuple:
    """Move the selection back one figure, wrapping to the last from the first.

    Updates the current index and selected figure in ``session_state`` and
    clears the conversation history and cached image path, since those belong
    to the previously selected figure.

    Returns ``(status_text, caption, image, session_state)``; when the session
    holds no figures the state is returned untouched with a placeholder
    display.
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None, session_state

    current_index = session_state.get("current_figure_index", 0)
    # Python's modulo is non-negative here, so index -1 wraps to the last figure.
    target_index = (current_index - 1) % len(figures)

    session_state["current_figure_index"] = target_index
    session_state["selected_figure"] = figures[target_index]
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None

    display = _get_figure_display(session_state)
    status_text, caption_text, figure_image = display
    return status_text, caption_text, figure_image, session_state
def ask_question_helper(question: str, session_state: dict) -> tuple:
    """Ask a question about the currently selected figure.

    Args:
        question: The user's question text.
        session_state: Per-session dict. ``selected_figure``,
            ``conversation_history`` and ``current_image_path`` are read from
            it; the latter two are overwritten with the values returned by
            ``answer_question``.

    Returns:
        ``(message, session_state)`` where ``message`` is the formatted
        question/answer pair, a prompt to enter a question or select a
        figure, or an ``"Error: ..."`` string if answering failed.
    """
    # Reject empty and whitespace-only input so a blank prompt is never sent
    # to the model.
    if not question or not question.strip():
        return "Enter a question", session_state

    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        return "No figure selected", session_state

    try:
        image = selected_fig['image']
        history = session_state.get("conversation_history", [])
        image_path = session_state.get("current_image_path")
        # answer_question's result is unpacked as the answer text plus the
        # updated history and image path, which are stored back into the
        # session for the next question on the same figure.
        result, updated_history, updated_image_path = answer_question(
            image, question, history, image_path
        )
        session_state["conversation_history"] = updated_history
        session_state["current_image_path"] = updated_image_path
        return f"Q: {question}\n\nA: {result}", session_state
    except Exception as e:
        # UI boundary: surface the failure in the answer box instead of
        # raising into the event handler.
        return f"Error: {str(e)}", session_state
def load_current_figure(session_state: dict) -> tuple:
    """Return ``(status_text, caption, image)`` for the session's current figure.

    Thin adapter over ``_get_figure_display`` for callbacks that refresh a
    figure panel without changing the session state.
    """
    return _get_figure_display(session_state)
def extract_csv_helper(session_state: dict) -> tuple:
    """Run chart-to-CSV extraction on the currently selected figure.

    Returns ``(message, session_state)``. On success the CSV text is stored
    under ``last_csv`` in the session and returned wrapped in a fenced
    ``csv`` block; otherwise ``message`` is ``"No figure selected"`` or an
    ``"Error: ..."`` string.
    """
    figure = session_state.get("selected_figure")
    if figure is None:
        return "No figure selected", session_state

    try:
        extracted = extract_csv(figure['image'])
        session_state["last_csv"] = extracted
        fenced = f"```csv\n{extracted}\n```"
        return fenced, session_state
    except Exception as exc:
        # UI boundary: report the failure in the output box instead of raising.
        return f"Error: {exc!s}", session_state
# BUILD APP WITH SIDE-BY-SIDE LAYOUT
def _wire_figure_navigation(prev_button, next_button, display_outputs, state):
    """Connect a tab's Previous/Next buttons to the shared figure callbacks.

    ``display_outputs`` are that tab's figure-info, caption and image
    components; the session state is appended because ``prev_figure`` and
    ``next_figure`` return it as their fourth value.
    """
    nav_outputs = [*display_outputs, state]
    prev_button.click(prev_figure, inputs=[state], outputs=nav_outputs)
    next_button.click(next_figure, inputs=[state], outputs=nav_outputs)


with gr.Blocks(title="DocAI") as app:
    gr.Markdown("# Granite Vision 4.0 Demo: Document Intelligence")

    # The factory itself (not its result) is handed to gr.State.
    session_state = gr.State(create_initial_state)

    with gr.Tabs():
        # TAB 1: UPLOAD & PARSE
        with gr.Tab("Parse & Extract"):
            with gr.Row():
                file_path = gr.File(label="Upload PDF", file_types=[".pdf"], scale=4)
                load_btn = gr.Button("Load PDF", variant="primary", scale=1)

            status = gr.Textbox(label="Status", interactive=False, lines=2)

            # Side-by-side: HTML on left, Figures on right
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Parsed Document using Docling")
                    html_view = gr.HTML(value="<p>Upload a PDF to see parsed content</p>")

                with gr.Column(scale=1):
                    gr.Markdown("### Extracted Figures")
                    fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    fig_caption = gr.Textbox(label="Caption", interactive=False)
                    fig_image = gr.Image(label="Figure", type="pil")
                    with gr.Row():
                        prev_btn = gr.Button("Previous", scale=1)
                        next_btn = gr.Button("Next", scale=1)

            # Wire callbacks
            load_btn.click(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[status, html_view, fig_info, fig_caption, fig_image, session_state],
            )
            _wire_figure_navigation(
                prev_btn, next_btn, [fig_info, fig_caption, fig_image], session_state
            )

        # TAB 2: IMAGE Q&A
        with gr.Tab("Image Q&A") as qa_tab:
            gr.Markdown("Ask questions about the selected figure")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    qa_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    qa_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    qa_fig_image = gr.Image(label="Figure", type="pil")
                    with gr.Row():
                        qa_prev_btn = gr.Button("Previous", scale=1)
                        qa_next_btn = gr.Button("Next", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### Q&A")
                    question = gr.Textbox(
                        label="Question",
                        lines=2,
                        placeholder="e.g., What is shown in this chart?",
                    )
                    ask_btn = gr.Button("Ask", variant="primary")
                    answer = gr.Textbox(label="Answer", lines=8, interactive=False)

            # Wire callbacks for figure navigation in Q&A tab
            _wire_figure_navigation(
                qa_prev_btn,
                qa_next_btn,
                [qa_fig_info, qa_fig_caption, qa_fig_image],
                session_state,
            )
            ask_btn.click(
                ask_question_helper,
                inputs=[question, session_state],
                outputs=[answer, session_state],
            )

            # Populate figure display when switching to this tab
            qa_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image],
            )

        # TAB 3: CSV EXTRACTION
        with gr.Tab("Chart2CSV") as csv_tab:
            gr.Markdown("Extract CSV data from the selected chart")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    csv_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    csv_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    csv_fig_image = gr.Image(label="Figure", type="pil")
                    with gr.Row():
                        csv_prev_btn = gr.Button("Previous", scale=1)
                        csv_next_btn = gr.Button("Next", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### CSV Extraction")
                    extract_btn = gr.Button("Extract CSV", variant="primary")
                    csv_out = gr.Textbox(label="CSV", lines=8, interactive=False)

            _wire_figure_navigation(
                csv_prev_btn,
                csv_next_btn,
                [csv_fig_info, csv_fig_caption, csv_fig_image],
                session_state,
            )
            extract_btn.click(
                extract_csv_helper,
                inputs=[session_state],
                outputs=[csv_out, session_state],
            )

            # Populate figure display when switching to this tab
            csv_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image],
            )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 without creating a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app.launch(**launch_options)