# app.py — NVIDIA_RETR question extractor (Hugging Face Space)
# Author: AkshitShubham — "Update app.py", commit defdb61 (verified)
import os
import requests
import json
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import re
import fitz
import zipfile
import tempfile
import time
import math
from datetime import datetime
import pandas as pd
# --- Configuration ---
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
if not NVIDIA_API_KEY:
raise ValueError("NVIDIA_API_KEY environment variable not set.")
NIM_API_URL = "https://integrate.api.nvidia.com/v1/chat/completions"
HEADERS = {
"Authorization": f"Bearer {NVIDIA_API_KEY}",
"Accept": "application/json",
"Content-Type": "application/json",
}
MODEL_MAX_WIDTH = 1648
MODEL_MAX_HEIGHT = 2048
# --- Folder Setup for PDF Output ---
OUTPUT_FOLDER = 'output_reports'
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# Global store for processed data (key is session_id)
PROCESSED_PAGES_STORE = {}
CROPPED_QUESTIONS_STORE = {}
# --- Helper Functions (Image Processing, API Calls) ---
def resize_image_if_needed(image: Image.Image) -> Image.Image:
    """Downscale *image* to fit the model's maximum dimensions.

    Aspect ratio is preserved; an image already within bounds is
    returned unchanged (same object, no copy).
    """
    w, h = image.size
    if w <= MODEL_MAX_WIDTH and h <= MODEL_MAX_HEIGHT:
        return image
    scale = min(MODEL_MAX_WIDTH / w, MODEL_MAX_HEIGHT / h)
    target_size = (int(w * scale), int(h * scale))
    return image.resize(target_size, Image.Resampling.LANCZOS)
def call_parse_api_base64(image_bytes: bytes):
    """Send one page image to the nemoretriever-parse NIM and return its JSON reply.

    Parameters:
        image_bytes: raw PNG bytes of the page image.

    Returns:
        dict: the parsed JSON body of the chat-completions response.

    Raises:
        gr.Error: on any transport or API failure, carrying the server's
            detail message when one is available.
    """
    try:
        base64_encoded_data = base64.b64encode(image_bytes)
        base64_string = base64_encoded_data.decode('utf-8')
        image_url = f"data:image/png;base64,{base64_string}"
        payload = {
            "model": "nvidia/nemoretriever-parse",
            "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": image_url}}]}],
            "tools": [{"type": "function", "function": {"name": "markdown_bbox"}}],
            "max_tokens": 2048,
        }
        response = requests.post(NIM_API_URL, headers=HEADERS, json=payload, timeout=300)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        error_detail = str(e)
        if e.response is not None:
            try:
                error_detail = e.response.json().get("detail", e.response.text)
            # FIX: catch ValueError, not just json.JSONDecodeError —
            # Response.json() on simplejson-backed requests builds raises a
            # ValueError subclass that json.JSONDecodeError does not cover.
            except ValueError:
                error_detail = e.response.text
        if "maximum context length" in str(error_detail):
            raise gr.Error(f"API Error: The document page is too dense for the model's context limit. Details: {error_detail}")
        raise gr.Error(f"API Error: {error_detail}")
def get_question_number(text: str) -> int:
    """Return the integer that *text* (stripped) starts with, or -1 if none."""
    leading_digits = re.match(r"^\d+", text.strip())
    if leading_digits is None:
        return -1
    return int(leading_digits.group(0))
def parse_page_ranges(range_str: str) -> set:
    """Parse a '1, 3, 5-10'-style selection string into a set of ints.

    Malformed tokens and reversed ranges (e.g. '9-3') are silently skipped;
    an empty or falsy string yields an empty set.
    """
    selected = set()
    if not range_str:
        return selected
    for token in range_str.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            if '-' in token:
                lo, hi = (int(v) for v in token.split('-'))
                if lo <= hi:
                    selected.update(range(lo, hi + 1))
            else:
                selected.add(int(token))
        except ValueError:
            # Non-numeric or malformed token — ignore it.
            continue
    return selected
# --- Core Cropping Logic ---
def process_and_crop(original_image: Image.Image, api_response: dict, split_page: bool):
    """Turn one parsed page into per-question crops.

    Returns a 4-tuple: (annotated page image, list of cropped images for the
    gallery, list of (q_num, image, filename) tuples sorted by question
    number, question count). On an unparseable API response the original
    image and empty results are returned.
    """
    try:
        first_call = api_response["choices"][0]["message"]["tool_calls"][0]
        all_elements = json.loads(first_call["function"]["arguments"])[0]
    except (KeyError, IndexError, json.JSONDecodeError):
        return original_image, [], [], 0

    def _top(q):
        return q['bbox']['ymin']

    starts = [el for el in all_elements if get_question_number(el.get("text", "")) > 0]
    if not starts:
        return original_image, [], [], 0

    annotated = original_image.copy()
    draw = ImageDraw.Draw(annotated)
    questions = []

    if split_page:
        # Two-column layout: split question starts at the horizontal midpoint.
        mid = 0.5
        left = sorted((q for q in starts if q['bbox']['xmin'] < mid), key=_top)
        right = sorted((q for q in starts if q['bbox']['xmin'] >= mid), key=_top)
        process_column(left, all_elements, (0.0, mid), draw, original_image, questions)
        process_column(right, all_elements, (mid, 1.0), draw, original_image, questions)
    else:
        process_column(sorted(starts, key=_top), all_elements, (0.0, 1.0), draw, original_image, questions)

    questions.sort(key=lambda item: item[0])
    gallery = [item[1] for item in questions]
    return annotated, gallery, questions, len(questions)
def process_column(column_starts, all_elements, column_bounds, img_draw, original_image, cropped_questions_list):
    """Crop each question in one column, drawing its box on *img_draw*.

    *column_starts* must already be sorted top-to-bottom; *column_bounds*
    is the (xmin, xmax) normalized horizontal extent of the column.
    Results are appended to *cropped_questions_list* as
    (q_num, cropped_image, filename) tuples.
    """
    img_width, img_height = original_image.size
    min_w, min_h = 100, 50  # skip implausibly small detections (pixels)
    col_lo, col_hi = column_bounds

    for idx, start in enumerate(column_starts):
        q_num = get_question_number(start['text'])
        top = start['bbox']['ymin']
        # The question's vertical slice extends to the next question start,
        # or to the bottom of the page for the last one.
        if idx + 1 < len(column_starts):
            bottom = column_starts[idx + 1]['bbox']['ymin']
        else:
            bottom = 1.0

        members = [
            el for el in all_elements
            if top <= el['bbox']['ymin'] < bottom and col_lo <= el['bbox']['xmin'] < col_hi
        ]
        if not members:
            continue

        # Tight bounding box around everything that belongs to this question.
        xmin = min(el['bbox']['xmin'] for el in members)
        xmax = max(el['bbox']['xmax'] for el in members)
        ymax = max(el['bbox']['ymax'] for el in members)
        box = (xmin * img_width, top * img_height, xmax * img_width, ymax * img_height)
        if (box[2] - box[0]) < min_w or (box[3] - box[1]) < min_h:
            continue

        img_draw.rectangle(box, outline="red", width=3)
        crop = original_image.crop(box)

        # Build a filesystem-safe name from the question's leading text.
        raw_text = start.get('text', '').strip()
        slug = re.sub(r'[^\w\s-]', '', raw_text)[:50].strip()
        slug = re.sub(r'\s+', '_', slug)
        filename = f"{q_num}-{slug}" if slug else f"Q_{q_num}"
        cropped_questions_list.append((q_num, crop, filename))
# --- ZIP Download Functions ---
def zip_selected_questions(selected_indices_str: str, session_id: str):
    """Bundle the selected cropped questions into a ZIP and return its path.

    Parameters:
        selected_indices_str: '1,3,5-7'-style selection of question numbers;
            blank means "all questions".
        session_id: key into CROPPED_QUESTIONS_STORE.

    Returns:
        str: path of the written ZIP file in the temp directory.

    Raises:
        gr.Error: when the session is unknown, no questions exist, or the
            selection parses to nothing.
    """
    if session_id not in CROPPED_QUESTIONS_STORE:
        raise gr.Error("No processed questions found.")
    cropped_questions = CROPPED_QUESTIONS_STORE[session_id]
    if not cropped_questions:
        raise gr.Error("No questions were extracted.")
    selected_indices = parse_page_ranges(selected_indices_str) if selected_indices_str.strip() else {item[0] for item in cropped_questions}
    if not selected_indices:
        raise gr.Error("Please enter valid question numbers/ranges.")
    zip_path = os.path.join(tempfile.gettempdir(), f"questions_{session_id}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zf:
        for q_num, img, filename in cropped_questions:
            if q_num in selected_indices:
                img_io = io.BytesIO()
                img.save(img_io, format='PNG')
                # BUG FIX: write each crop under its own per-question filename;
                # the previous hard-coded name made every entry overwrite the last.
                zf.writestr(f"{filename}.png", img_io.getvalue())
    return zip_path
def zip_selected_pages(selected_indices_str: str, session_id: str):
    """Bundle the selected annotated pages into a ZIP and return its path.

    *selected_indices_str* uses 1-based page numbers; blank means all pages.
    Raises gr.Error when the session is unknown, empty, or the selection
    parses to nothing.
    """
    if session_id not in PROCESSED_PAGES_STORE:
        raise gr.Error("No processed results found.")
    pages = PROCESSED_PAGES_STORE[session_id]
    if not pages:
        raise gr.Error("No pages were processed.")
    if selected_indices_str.strip():
        wanted = parse_page_ranges(selected_indices_str)
    else:
        wanted = set(range(1, len(pages) + 1))
    if not wanted:
        raise gr.Error("Please enter valid page numbers/ranges.")
    zip_path = os.path.join(tempfile.gettempdir(), f"pages_{session_id}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zf:
        for page_no in wanted:
            idx = page_no - 1
            if not (0 <= idx < len(pages)):
                continue  # silently skip out-of-range page numbers
            buf = io.BytesIO()
            pages[idx].save(buf, format='PNG')
            zf.writestr(f"Page_{page_no}_boxed.png", buf.getvalue())
    return zip_path
# --- PDF Generation Functions (Integrated) ---
def get_or_download_font(font_path="arial.ttf", font_size=50):
if not os.path.exists(font_path):
try:
print("Downloading arial.ttf font...")
response = requests.get("https://github.com/matomo-org/travis-scripts/raw/master/fonts/arial.ttf", timeout=30)
response.raise_for_status()
with open(font_path, 'wb') as f: f.write(response.content)
print("Font downloaded.")
except Exception as e:
print(f"Font download failed: {e}. Using default font.")
return ImageFont.load_default()
try:
return ImageFont.truetype(font_path, size=font_size)
except IOError:
print("Arial font not found or failed to load. Using default font.")
return ImageFont.load_default()
def create_a4_pdf_from_images(image_info, images_per_page, pdf_filename_base, orientation="Auto", progress=None):
    """Lay out question images in a grid on A4 pages and save a multi-page PDF.

    Parameters:
        image_info: list of dicts, one per question; each must hold an
            'image' (PIL.Image) plus metadata fields printed above it
            ('Include' and 'Question Number' are handled specially).
        images_per_page: maximum images placed on one PDF page.
        pdf_filename_base: filename stem; a timestamp and '.pdf' are appended.
        orientation: "Portrait", "Landscape", or "Auto" (landscape when the
            page's images are, on average, wider than tall).
        progress: optional Gradio progress callback.

    Returns:
        str | None: path of the written PDF, or None if no pages were made.
    """
    if not image_info:
        return None
    A4_PORTRAIT_WIDTH, A4_PORTRAIT_HEIGHT = 2480, 3508  # A4 at ~300 DPI
    font_large, font_small = get_or_download_font(font_size=40), get_or_download_font(font_size=28)
    output_filename = f"{pdf_filename_base}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
    output_path = os.path.join(OUTPUT_FOLDER, output_filename)
    pages = []
    info_chunks = [image_info[i:i + images_per_page] for i in range(0, len(image_info), images_per_page)]
    if progress:
        progress(0, desc="Preparing PDF pages...")
    for chunk_idx, chunk in enumerate(info_chunks):
        page_width, page_height = A4_PORTRAIT_WIDTH, A4_PORTRAIT_HEIGHT
        if orientation == "Landscape":
            page_width, page_height = A4_PORTRAIT_HEIGHT, A4_PORTRAIT_WIDTH
        elif orientation == "Auto":
            # Rotate the page when this chunk's images are mostly wide.
            total_aspect_ratio = sum(info['image'].width / info['image'].height for info in chunk)
            avg_aspect_ratio = total_aspect_ratio / len(chunk) if chunk else 1
            if avg_aspect_ratio > 1.1:
                page_width, page_height = A4_PORTRAIT_HEIGHT, A4_PORTRAIT_WIDTH
        page_canvas = Image.new('RGB', (page_width, page_height), 'white')
        draw = ImageDraw.Draw(page_canvas)
        num_images_on_page = len(chunk)
        # Near-square grid: cols = ceil(sqrt(n)), rows = ceil(n / cols).
        cols = int(math.ceil(math.sqrt(num_images_on_page)))
        rows = int(math.ceil(num_images_on_page / cols))
        margin, gutter, header_space = 150, 60, 140  # Reduced header space
        cell_width = (page_width - 2 * margin - (cols - 1) * gutter) // cols
        cell_height = (page_height - 2 * margin - (rows - 1) * gutter) // rows
        for i, info in enumerate(chunk):
            # BUG FIX: the row index is i // cols, not i // rows. The old
            # formula overlapped cells and pasted past the page whenever
            # rows != cols (e.g. 5 or 6 images per page).
            row, col = divmod(i, cols)
            cell_x = margin + col * (cell_width + gutter)
            cell_y = margin + row * (cell_height + gutter)
            img = info['image']
            draw.text((cell_x + 15, cell_y + 10), f"Q.No: {info.get('Question Number', 'N/A')}", fill="black", font=font_large)
            info_y_offset = 60
            for key, value in info.items():
                # Print each non-empty metadata field, truncated to 40 chars.
                if key not in {'image', 'Question Number', 'Include'} and value and str(value).strip():
                    display_text = f"{key.replace('_', ' ').title()}: {str(value)[:40]}"
                    draw.text((cell_x + 15, cell_y + info_y_offset), display_text, fill="dimgray", font=font_small)
                    info_y_offset += 35
            img_area_width = cell_width
            img_area_height = cell_height - header_space  # text header sits above
            # Fit the image inside the cell while preserving aspect ratio.
            original_width, original_height = img.size
            if original_width > 0 and original_height > 0:
                smaller_ratio = min(img_area_width / original_width, img_area_height / original_height)
                new_width = int(original_width * smaller_ratio)
                new_height = int(original_height * smaller_ratio)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            # Align to top-left of the image area.
            page_canvas.paste(img, (cell_x, cell_y + header_space))
        pages.append(page_canvas)
        if progress:
            progress((chunk_idx + 1) / len(info_chunks), desc=f"Generated page {chunk_idx + 1}/{len(info_chunks)}")
    if pages:
        if progress:
            progress(1, desc="Saving PDF...")
        pages[0].save(output_path, "PDF", resolution=300.0, save_all=True, append_images=pages[1:])
        return output_path
    return None
# --- Main Gradio Function ---
def question_extractor_app(pdf_file, image_file, split_page_toggle, page_selection_str, progress=gr.Progress()):
    """Main pipeline: rasterize the upload, parse each page with the NIM API,
    crop questions, and populate the session stores.

    Parameters:
        pdf_file / image_file: exactly one must be provided (Gradio File).
        split_page_toggle: treat pages as two-column layouts when True.
        page_selection_str: '1,3,5-10'-style page selection (PDF only).
        progress: Gradio progress tracker.

    Returns:
        Tuple of component updates: (processed pages, cropped gallery,
        summary text, session id, pages info, questions info, metadata
        DataFrame, report-group visibility, column dropdown).

    Raises:
        gr.Error: on invalid input or an unrecoverable API failure.
    """
    if pdf_file and image_file:
        raise gr.Error("Please upload either a PDF or an Image, not both.")
    input_file = pdf_file or image_file
    if not input_file:
        raise gr.Error("Please upload a file.")
    page_data_for_processing = []
    if input_file.name.lower().endswith('.pdf'):
        doc = fitz.open(input_file.name)
        try:
            selected_pages = parse_page_ranges(page_selection_str)
            # FIX: sort the selection — sets iterate in arbitrary order, so
            # pages were previously processed (and displayed) out of order.
            page_indices = sorted(p - 1 for p in selected_pages) if selected_pages else list(range(len(doc)))
            for i, page_num in enumerate(page_indices):
                if not (0 <= page_num < len(doc)):
                    continue
                page = doc.load_page(page_num)
                processed_successfully = False
                # Try 300 DPI first; fall back to 150 DPI when the page is
                # too dense for the model's context window.
                for dpi in [300, 150]:
                    progress((i + 0.5) / len(page_indices), desc=f"Page {page_num + 1} at {dpi} DPI")
                    try:
                        pix = page.get_pixmap(dpi=dpi)
                        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                        resized_img = resize_image_if_needed(img)
                        with io.BytesIO() as buf:
                            resized_img.save(buf, format='PNG')
                            api_response = call_parse_api_base64(buf.getvalue())
                        page_data_for_processing.append((resized_img, api_response))
                        processed_successfully = True
                        break
                    except gr.Error as e:
                        if "maximum context length" in str(e) and dpi == 300:
                            print(f"Warning: Page {page_num + 1} too dense at 300 DPI. Retrying at 150 DPI.")
                            continue
                        else:
                            raise e
                if not processed_successfully:
                    raise gr.Error(f"Failed to process page {page_num + 1} even at lower resolutions.")
        finally:
            # FIX: release the PyMuPDF document handle (it was never closed).
            doc.close()
    else:
        img = Image.open(input_file.name).convert("RGB")
        resized_img = resize_image_if_needed(img)
        with io.BytesIO() as buf:
            resized_img.save(buf, format='PNG')
            api_response = call_parse_api_base64(buf.getvalue())
        page_data_for_processing.append((resized_img, api_response))
    if not page_data_for_processing:
        return [], [], "No pages selected or file is empty.", "", "", "", pd.DataFrame(), gr.Group(visible=False), gr.Dropdown(choices=[])
    all_processed_pages, all_gallery_images, all_question_data = [], [], []
    for resized_img, api_response in page_data_for_processing:
        boxed_img, page_gallery, page_q_data, _ = process_and_crop(resized_img, api_response, split_page_toggle)
        all_processed_pages.append(boxed_img)
        all_gallery_images.extend(page_gallery)
        all_question_data.extend(page_q_data)
    summary = f"Processed {len(page_data_for_processing)} page(s) and found {len(all_question_data)} questions."
    # Timestamp-derived key into the module-level result stores.
    session_id = str(time.time()).replace('.', '')
    PROCESSED_PAGES_STORE[session_id] = all_processed_pages
    CROPPED_QUESTIONS_STORE[session_id] = all_question_data
    pages_info = f"Available: {', '.join(str(i+1) for i in range(len(all_processed_pages)))}"
    questions_info = f"Available: {', '.join(str(item[0]) for item in all_question_data)}"
    report_df = pd.DataFrame({
        "Include": [True] * len(all_question_data),
        "Question Number": [item[0] for item in all_question_data],
        "Subject": ["" for _ in all_question_data],
        "Topic": ["" for _ in all_question_data],
        "Difficulty": pd.Categorical([""] * len(all_question_data), categories=["", "Easy", "Medium", "Hard"]),
        "Status": pd.Categorical([""] * len(all_question_data), categories=["", "Correct", "Wrong", "Unattempted"])
    })
    column_choices = report_df.columns.tolist()
    return (
        all_processed_pages, all_gallery_images, summary, session_id,
        pages_info, questions_info, report_df, gr.Group(visible=True),
        gr.Dropdown(choices=column_choices, interactive=True)
    )
def generate_report_pdf(session_id: str, report_df: pd.DataFrame, pdf_name: str, images_per_page: int, orientation: str, progress=gr.Progress(track_tqdm=True)):
    """Build the PDF report from the rows still checked 'Include' in the table.

    Raises gr.Error when the session is unknown, nothing is selected, or PDF
    generation produces no pages; otherwise returns a gr.File update
    pointing at the written PDF.
    """
    if session_id not in CROPPED_QUESTIONS_STORE:
        raise gr.Error("Session expired or invalid. Please re-process the files.")
    selected_rows = report_df[report_df['Include']].to_dict('records')
    if not selected_rows:
        raise gr.Error("No questions selected to include in the report.")
    # Map question number -> cropped image for quick lookup.
    images_by_qnum = {q_num: img for q_num, img, _ in CROPPED_QUESTIONS_STORE[session_id]}
    pdf_entries = []
    for row in selected_rows:
        q_num = row['Question Number']
        if q_num not in images_by_qnum:
            continue
        entry = dict(row)
        entry['image'] = images_by_qnum[q_num]
        pdf_entries.append(entry)
    base_name = re.sub(r'[^\w-]', '_', pdf_name) if pdf_name else "Question_Report"
    pdf_path = create_a4_pdf_from_images(pdf_entries, int(images_per_page), base_name, orientation, progress)
    if not pdf_path:
        raise gr.Error("Failed to generate PDF. No pages were created.")
    return gr.File(value=pdf_path, label="Download PDF Report")
# --- Gradio UI Layout ---
# --- Gradio UI Layout ---
if __name__ == "__main__":
    with gr.Blocks(title="NIM Question Extractor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📄 NVIDIA NIM Question Extractor & Report Generator")
        gr.Markdown("Extract questions, add custom metadata, and generate an optimized PDF report.")
        # Holds the per-run key into PROCESSED_PAGES_STORE / CROPPED_QUESTIONS_STORE.
        session_id_state = gr.State()
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## 1. Input & Options")
                    pdf_input = gr.File(label="Upload PDF File", file_types=['.pdf'])
                    image_input = gr.File(label="Upload Image File", file_types=['.png', '.jpg', '.jpeg'])
                    page_select_input = gr.Textbox(label="Select Pages (PDF only)", placeholder="e.g., 1, 3, 5-10")
                    split_toggle = gr.Checkbox(label="Two-Column Layout")
                    submit_btn = gr.Button("🚀 Start Question Extraction", variant="primary")
                with gr.Group():
                    gr.Markdown("## 2. Download Raw Images")
                    with gr.Accordion("Download ZIP Files", open=False):
                        download_pages_info = gr.Textbox(label="Available Pages", interactive=False)
                        download_pages_input = gr.Textbox(label="Select Pages to ZIP", placeholder="Leave blank for all")
                        # Enabled only after a successful extraction (see .then below).
                        download_pages_btn = gr.DownloadButton("📥 Pages ZIP", variant="secondary", interactive=False)
                        download_questions_info = gr.Textbox(label="Available Questions", interactive=False)
                        download_questions_input = gr.Textbox(label="Select Questions to ZIP", placeholder="Leave blank for all")
                        download_questions_btn = gr.DownloadButton("📥 Questions ZIP", variant="secondary", interactive=False)
            with gr.Column(scale=2):
                gr.Markdown("## 3. Review Extraction")
                output_summary = gr.Textbox(label="Processing Summary", interactive=False)
                with gr.Tab("Processed Pages (with boxes)"):
                    output_processed_pages = gr.Gallery(label="Pages with Boundaries", height=400, columns=2, object_fit="contain", show_label=False)
                with gr.Tab("Individual Questions"):
                    output_cropped_gallery = gr.Gallery(label="Cropped Questions", height=400, columns=4, object_fit="contain", show_label=False)
        # Hidden until extraction succeeds; question_extractor_app returns a
        # gr.Group(visible=True) update for this group.
        with gr.Group(visible=False) as report_group:
            gr.Markdown("--- \n ## 4. Create PDF Report")
            gr.Markdown("Edit the table below to add metadata. Uncheck 'Include' to exclude a question from the report.")
            with gr.Accordion("Bulk Edit Tools", open=False):
                with gr.Row():
                    select_all_btn = gr.Button("Select All")
                    deselect_all_btn = gr.Button("Deselect All")
                with gr.Row():
                    column_select_dropdown = gr.Dropdown(label="Select Column", interactive=False)
                    value_to_apply_input = gr.Textbox(label="Value to Apply", placeholder="e.g., Physics")
                    apply_to_col_btn = gr.Button("Apply Value to Column")
                with gr.Row():
                    new_col_name_input = gr.Textbox(label="Custom Column Name", placeholder="e.g., Source Book")
                    add_col_btn = gr.Button("Add Column")
            report_metadata_df = gr.DataFrame(
                headers=["Include", "Question Number", "Subject", "Topic", "Difficulty", "Status"],
                datatype=["bool", "number", "str", "str", "categorical", "categorical"],
                interactive=True
            )
            with gr.Accordion("PDF Layout Options", open=True):
                with gr.Row():
                    pdf_name_input = gr.Textbox("Question_Report", label="PDF Filename", scale=2)
                    images_per_page_input = gr.Slider(1, 16, value=4, step=1, label="Images Per Page", scale=2)
                    orientation_radio = gr.Radio(["Auto", "Portrait", "Landscape"], label="Page Orientation", value="Auto", scale=1)
            generate_pdf_btn = gr.Button("📄 Generate PDF Report", variant="primary")
            pdf_output_file = gr.File(label="Download PDF Report", interactive=False)
        # --- Event Handlers ---
        def toggle_include_all(df, select_all_flag):
            """Set every row's 'Include' flag to *select_all_flag*."""
            if not df.empty:
                df['Include'] = select_all_flag
            return df
        def apply_value_to_column(df, col_name, value):
            """Overwrite an entire existing column with *value*."""
            if col_name and col_name in df.columns and value is not None:
                df[col_name] = value
            return df
        select_all_btn.click(
            fn=lambda df: toggle_include_all(df, True),
            inputs=[report_metadata_df],
            outputs=[report_metadata_df]
        )
        deselect_all_btn.click(
            fn=lambda df: toggle_include_all(df, False),
            inputs=[report_metadata_df],
            outputs=[report_metadata_df]
        )
        apply_to_col_btn.click(
            fn=apply_value_to_column,
            inputs=[report_metadata_df, column_select_dropdown, value_to_apply_input],
            outputs=[report_metadata_df]
        )
        def add_custom_column(df, col_name):
            """Append an empty custom column and refresh the dropdown choices."""
            if col_name and col_name not in df.columns and not df.empty:
                df[col_name] = ""
            # Return updated dataframe and update the choices for the dropdown
            return df, gr.Dropdown(choices=df.columns.tolist(), interactive=True)
        add_col_btn.click(
            fn=add_custom_column,
            inputs=[report_metadata_df, new_col_name_input],
            outputs=[report_metadata_df, column_select_dropdown]
        )
        submit_btn.click(
            fn=question_extractor_app,
            inputs=[pdf_input, image_input, split_toggle, page_select_input],
            outputs=[output_processed_pages, output_cropped_gallery, output_summary, session_id_state,
                     download_pages_info, download_questions_info, report_metadata_df, report_group, column_select_dropdown]
        ).then(
            # Enable the ZIP download buttons once extraction has finished.
            lambda: (gr.DownloadButton(interactive=True), gr.DownloadButton(interactive=True)),
            outputs=[download_pages_btn, download_questions_btn]
        )
        download_pages_btn.click(
            fn=zip_selected_pages, inputs=[download_pages_input, session_id_state], outputs=[download_pages_btn]
        )
        download_questions_btn.click(
            fn=zip_selected_questions, inputs=[download_questions_input, session_id_state], outputs=[download_questions_btn]
        )
        generate_pdf_btn.click(
            fn=generate_report_pdf,
            inputs=[session_id_state, report_metadata_df, pdf_name_input, images_per_page_input, orientation_radio],
            outputs=[pdf_output_file]
        )
    demo.launch(debug=True)