# Hugging Face Spaces page header captured with the source dump
# ("Spaces" navigation label and "Sleeping" runtime status) — this is
# web-page residue, not part of the application code.
| import os | |
| import requests | |
| import json | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import io | |
| import base64 | |
| import re | |
| import fitz | |
| import zipfile | |
| import tempfile | |
| import time | |
| import math | |
| from datetime import datetime | |
| import pandas as pd | |
# --- Configuration ---
# The NIM API key must be supplied via the environment; fail fast at import
# time rather than on the first request.
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
if not NVIDIA_API_KEY:
    raise ValueError("NVIDIA_API_KEY environment variable not set.")
NIM_API_URL = "https://integrate.api.nvidia.com/v1/chat/completions"
HEADERS = {
    "Authorization": f"Bearer {NVIDIA_API_KEY}",
    "Accept": "application/json",
    "Content-Type": "application/json",
}
# Largest image dimensions (pixels) sent to the parse model; larger inputs
# are downscaled by resize_image_if_needed().
MODEL_MAX_WIDTH = 1648
MODEL_MAX_HEIGHT = 2048
# --- Folder Setup for PDF Output ---
OUTPUT_FOLDER = 'output_reports'
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# Global store for processed data (key is session_id)
# NOTE(review): these in-memory stores are never evicted — presumably fine
# for a single-user Space, but a long-running server will grow unbounded.
PROCESSED_PAGES_STORE = {}
CROPPED_QUESTIONS_STORE = {}
| # --- Helper Functions (Image Processing, API Calls) --- | |
def resize_image_if_needed(image: Image.Image) -> Image.Image:
    """Downscale *image* to fit the model's max dimensions, preserving aspect ratio.

    Images already within MODEL_MAX_WIDTH x MODEL_MAX_HEIGHT are returned
    unchanged; otherwise a LANCZOS-resampled copy is returned.
    """
    w, h = image.size
    if w <= MODEL_MAX_WIDTH and h <= MODEL_MAX_HEIGHT:
        return image  # already small enough — no work to do
    scale = min(MODEL_MAX_WIDTH / w, MODEL_MAX_HEIGHT / h)
    target_size = (int(w * scale), int(h * scale))
    return image.resize(target_size, Image.Resampling.LANCZOS)
def call_parse_api_base64(image_bytes: bytes):
    """POST a PNG image to the nemoretriever-parse NIM endpoint.

    The image is embedded as a base64 data URI. Returns the decoded JSON
    response; any network/HTTP failure is surfaced as a gr.Error with the
    most specific detail message available from the response body.
    """
    encoded = base64.b64encode(image_bytes).decode('utf-8')
    data_uri = f"data:image/png;base64,{encoded}"
    request_body = {
        "model": "nvidia/nemoretriever-parse",
        "messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": data_uri}}]}],
        "tools": [{"type": "function", "function": {"name": "markdown_bbox"}}],
        "max_tokens": 2048,
    }
    try:
        resp = requests.post(NIM_API_URL, headers=HEADERS, json=request_body, timeout=300)
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as exc:
        # Prefer the server's structured "detail" field, then raw body text.
        detail = str(exc)
        if exc.response is not None:
            try:
                detail = exc.response.json().get("detail", exc.response.text)
            except json.JSONDecodeError:
                detail = exc.response.text
        if "maximum context length" in str(detail):
            raise gr.Error(f"API Error: The document page is too dense for the model's context limit. Details: {detail}")
        raise gr.Error(f"API Error: {detail}")
def get_question_number(text: str) -> int:
    """Return the integer that *text* starts with, or -1 if it has no leading digits."""
    leading_digits = re.match(r"^\d+", text.strip())
    if leading_digits is None:
        return -1
    return int(leading_digits.group(0))
def parse_page_ranges(range_str: str) -> set:
    """Parse a spec like "1, 3, 5-10" into a set of ints.

    Malformed tokens and inverted ranges (e.g. "7-3") are silently skipped;
    an empty/None spec yields an empty set.
    """
    selected = set()
    if not range_str:
        return selected
    for token in range_str.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            if '-' in token:
                lo, hi = map(int, token.split('-'))
                if lo <= hi:
                    selected.update(range(lo, hi + 1))
            else:
                selected.add(int(token))
        except ValueError:
            continue  # ignore anything that isn't a number or a range
    return selected
| # --- Core Cropping Logic --- | |
def process_and_crop(original_image: Image.Image, api_response: dict, split_page: bool):
    """Turn the parse-API response into per-question crops for one page.

    Returns (annotated page image, gallery images, [(q_num, crop, filename)],
    question count). If the response is malformed or contains no question
    starts, the original image is returned untouched with empty results.
    """
    try:
        first_call = api_response["choices"][0]["message"]["tool_calls"][0]
        all_elements = json.loads(first_call["function"]["arguments"])[0]
    except (KeyError, IndexError, json.JSONDecodeError):
        return original_image, [], [], 0

    # An element whose text begins with a positive integer marks a question start.
    question_starts = [el for el in all_elements if get_question_number(el.get("text", "")) > 0]
    if not question_starts:
        return original_image, [], [], 0

    annotated = original_image.copy()
    drawer = ImageDraw.Draw(annotated)
    crops = []
    by_ymin = lambda q: q['bbox']['ymin']
    if split_page:
        # Two-column layout: split at the normalized horizontal midpoint and
        # process each column top-to-bottom independently.
        midpoint = 0.5
        left = sorted((q for q in question_starts if q['bbox']['xmin'] < midpoint), key=by_ymin)
        right = sorted((q for q in question_starts if q['bbox']['xmin'] >= midpoint), key=by_ymin)
        process_column(left, all_elements, (0.0, midpoint), drawer, original_image, crops)
        process_column(right, all_elements, (midpoint, 1.0), drawer, original_image, crops)
    else:
        process_column(sorted(question_starts, key=by_ymin), all_elements, (0.0, 1.0), drawer, original_image, crops)
    crops.sort(key=lambda item: item[0])  # order by question number
    gallery = [item[1] for item in crops]
    return annotated, gallery, crops, len(crops)
def process_column(column_starts, all_elements, column_bounds, img_draw, original_image, cropped_questions_list):
    """Crop each question slice within one column and record the results.

    column_starts: question-start elements, sorted top-to-bottom.
    column_bounds: (xmin, xmax) normalized horizontal extent of the column.
    Draws a red rectangle per accepted crop via img_draw and appends
    (question number, cropped image, filename stem) tuples to
    cropped_questions_list in place.
    """
    img_width, img_height = original_image.size
    MIN_CROP_WIDTH, MIN_CROP_HEIGHT = 100, 50
    x_lo, x_hi = column_bounds
    for idx, start in enumerate(column_starts):
        q_num = get_question_number(start['text'])
        top = start['bbox']['ymin']
        # A question's vertical extent ends where the next question starts
        # (or at the bottom of the page for the last one).
        bottom_limit = column_starts[idx + 1]['bbox']['ymin'] if idx + 1 < len(column_starts) else 1.0
        members = [
            el for el in all_elements
            if top <= el['bbox']['ymin'] < bottom_limit and x_lo <= el['bbox']['xmin'] < x_hi
        ]
        if not members:
            continue
        left = min(el['bbox']['xmin'] for el in members)
        right = max(el['bbox']['xmax'] for el in members)
        bottom = max(el['bbox']['ymax'] for el in members)
        abs_box = (left * img_width, top * img_height, right * img_width, bottom * img_height)
        if (abs_box[2] - abs_box[0]) < MIN_CROP_WIDTH or (abs_box[3] - abs_box[1]) < MIN_CROP_HEIGHT:
            continue  # discard slivers too small to be a real question
        img_draw.rectangle(abs_box, outline="red", width=3)
        crop = original_image.crop(abs_box)
        # Build a filesystem-safe name stem from the question's leading text.
        raw_text = start.get('text', '').strip()
        slug = re.sub(r'[^\w\s-]', '', raw_text)[:50].strip()
        slug = re.sub(r'\s+', '_', slug)
        filename = f"{q_num}-{slug}" if slug else f"Q_{q_num}"
        cropped_questions_list.append((q_num, crop, filename))
| # --- ZIP Download Functions --- | |
def zip_selected_questions(selected_indices_str: str, session_id: str):
    """Zip the cropped question images for *session_id* and return the archive path.

    selected_indices_str: question numbers/ranges (e.g. "1, 3-5"); a blank
        selection means every extracted question.
    Raises gr.Error when the session is unknown, nothing was extracted, or the
    selection parses to an empty set.
    """
    if session_id not in CROPPED_QUESTIONS_STORE:
        raise gr.Error("No processed questions found.")
    cropped_questions = CROPPED_QUESTIONS_STORE[session_id]
    if not cropped_questions:
        raise gr.Error("No questions were extracted.")
    selected_indices = parse_page_ranges(selected_indices_str) if selected_indices_str.strip() else {item[0] for item in cropped_questions}
    if not selected_indices:
        raise gr.Error("Please enter valid question numbers/ranges.")
    zip_path = os.path.join(tempfile.gettempdir(), f"questions_{session_id}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zf:
        for q_num, img, filename in cropped_questions:
            if q_num in selected_indices:
                img_io = io.BytesIO()
                img.save(img_io, format='PNG')
                # BUG FIX: every entry used to be written under the literal
                # name "(unknown).png", so each question overwrote the previous
                # one in the archive; use the per-question filename instead.
                zf.writestr(f"{filename}.png", img_io.getvalue())
    return zip_path
def zip_selected_pages(selected_indices_str: str, session_id: str):
    """Zip the boxed page images for *session_id* and return the archive path.

    Page numbers are 1-based; a blank selection means every processed page.
    Raises gr.Error when the session is unknown, empty, or the selection
    parses to nothing.
    """
    if session_id not in PROCESSED_PAGES_STORE:
        raise gr.Error("No processed results found.")
    processed_pages = PROCESSED_PAGES_STORE[session_id]
    if not processed_pages:
        raise gr.Error("No pages were processed.")
    if selected_indices_str.strip():
        selected_indices = parse_page_ranges(selected_indices_str)
    else:
        selected_indices = set(range(1, len(processed_pages) + 1))
    if not selected_indices:
        raise gr.Error("Please enter valid page numbers/ranges.")
    zip_path = os.path.join(tempfile.gettempdir(), f"pages_{session_id}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zf:
        for user_page_num in selected_indices:
            idx = user_page_num - 1
            if not (0 <= idx < len(processed_pages)):
                continue  # silently ignore out-of-range page requests
            buffer = io.BytesIO()
            processed_pages[idx].save(buffer, format='PNG')
            zf.writestr(f"Page_{user_page_num}_boxed.png", buffer.getvalue())
    return zip_path
| # --- PDF Generation Functions (Integrated) --- | |
def get_or_download_font(font_path="arial.ttf", font_size=50):
    """Return an Arial TrueType font at *font_size*, downloading the file once.

    Falls back to PIL's built-in bitmap font when the download or the
    TrueType load fails.
    """
    if not os.path.exists(font_path):
        try:
            print("Downloading arial.ttf font...")
            resp = requests.get("https://github.com/matomo-org/travis-scripts/raw/master/fonts/arial.ttf", timeout=30)
            resp.raise_for_status()
            with open(font_path, 'wb') as fh:
                fh.write(resp.content)
            print("Font downloaded.")
        except Exception as e:
            print(f"Font download failed: {e}. Using default font.")
            return ImageFont.load_default()
    try:
        return ImageFont.truetype(font_path, size=font_size)
    except IOError:
        print("Arial font not found or failed to load. Using default font.")
        return ImageFont.load_default()
def create_a4_pdf_from_images(image_info, images_per_page, pdf_filename_base, orientation="Auto", progress=None):
    """Lay out question images plus their metadata onto A4 pages and save a PDF.

    image_info: list of dicts, each containing an 'image' (PIL.Image) and
        metadata keys ('Question Number', 'Subject', ...) printed above it.
    images_per_page: images per page, arranged in a near-square grid.
    pdf_filename_base: sanitized base name; a timestamp suffix is appended.
    orientation: "Portrait", "Landscape", or "Auto" (per page, landscape is
        chosen when the chunk's average aspect ratio exceeds 1.1).
    progress: optional callable(fraction, desc=...) for UI feedback.

    Returns the output PDF path, or None when there is nothing to render.
    """
    if not image_info: return None
    # A4 page size in pixels at 300 DPI.
    A4_PORTRAIT_WIDTH, A4_PORTRAIT_HEIGHT = 2480, 3508
    font_large, font_small = get_or_download_font(font_size=40), get_or_download_font(font_size=28)
    output_filename = f"{pdf_filename_base}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
    output_path = os.path.join(OUTPUT_FOLDER, output_filename)
    pages = []
    info_chunks = [image_info[i:i + images_per_page] for i in range(0, len(image_info), images_per_page)]
    if progress: progress(0, desc="Preparing PDF pages...")
    for chunk_idx, chunk in enumerate(info_chunks):
        page_width, page_height = A4_PORTRAIT_WIDTH, A4_PORTRAIT_HEIGHT
        if orientation == "Landscape":
            page_width, page_height = A4_PORTRAIT_HEIGHT, A4_PORTRAIT_WIDTH
        elif orientation == "Auto":
            # Landscape page when this chunk's images are, on average, wider than tall.
            total_aspect_ratio = sum(info['image'].width / info['image'].height for info in chunk)
            avg_aspect_ratio = total_aspect_ratio / len(chunk) if chunk else 1
            if avg_aspect_ratio > 1.1:
                page_width, page_height = A4_PORTRAIT_HEIGHT, A4_PORTRAIT_WIDTH
        page_canvas = Image.new('RGB', (page_width, page_height), 'white')
        draw = ImageDraw.Draw(page_canvas)
        num_images_on_page = len(chunk)
        # Near-square grid with cols >= rows.
        cols = int(math.ceil(math.sqrt(num_images_on_page)))
        rows = int(math.ceil(num_images_on_page / cols))
        margin, gutter, header_space = 150, 60, 140  # Reduced header space
        cell_width = (page_width - 2 * margin - (cols - 1) * gutter) // cols
        cell_height = (page_height - 2 * margin - (rows - 1) * gutter) // rows
        for i, info in enumerate(chunk):
            # BUG FIX: the row index must divide by the COLUMN count
            # (i // cols), not the row count. The old `i // rows` only worked
            # when rows == cols; e.g. 6 images (3x2 grid) pushed image i=4 to
            # row 2, overflowing the page.
            col, row = i % cols, i // cols
            cell_x = margin + col * (cell_width + gutter)
            cell_y = margin + row * (cell_height + gutter)
            img = info['image']
            draw.text((cell_x + 15, cell_y + 10), f"Q.No: {info.get('Question Number', 'N/A')}", fill="black", font=font_large)
            # Print the remaining non-empty metadata fields below the Q number.
            info_y_offset = 60
            for key, value in info.items():
                if key not in {'image', 'Question Number', 'Include'} and value and str(value).strip():
                    display_text = f"{key.replace('_', ' ').title()}: {str(value)[:40]}"
                    draw.text((cell_x + 15, cell_y + info_y_offset), display_text, fill="dimgray", font=font_small)
                    info_y_offset += 35
            img_area_width = cell_width
            img_area_height = cell_height - header_space
            # Resize image based on the user's specified algorithm
            original_width, original_height = img.size
            if original_width > 0 and original_height > 0:
                ratio_w = img_area_width / original_width
                ratio_h = img_area_height / original_height
                smaller_ratio = min(ratio_w, ratio_h)
                new_width = int(original_width * smaller_ratio)
                new_height = int(original_height * smaller_ratio)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            # Align to top-left of the image area
            paste_x = cell_x
            paste_y = cell_y + header_space
            page_canvas.paste(img, (paste_x, paste_y))
        pages.append(page_canvas)
        if progress:
            progress((chunk_idx + 1) / len(info_chunks), desc=f"Generated page {chunk_idx + 1}/{len(info_chunks)}")
    if pages:
        if progress: progress(1, desc="Saving PDF...")
        pages[0].save(output_path, "PDF", resolution=300.0, save_all=True, append_images=pages[1:])
        return output_path
    return None
| # --- Main Gradio Function --- | |
def question_extractor_app(pdf_file, image_file, split_page_toggle, page_selection_str, progress=gr.Progress()):
    """Main pipeline: render input pages, call the parse API, and crop questions.

    Exactly one of pdf_file/image_file must be provided. Returns a 9-tuple
    matching the Gradio outputs wired in the UI: (processed page images,
    cropped question gallery, summary text, session_id, pages info string,
    questions info string, report DataFrame, report-group visibility update,
    column dropdown update).
    """
    if pdf_file and image_file:
        raise gr.Error("Please upload either a PDF or an Image, not both.")
    input_file = pdf_file or image_file
    if not input_file:
        raise gr.Error("Please upload a file.")
    page_data_for_processing = []
    if input_file.name.lower().endswith('.pdf'):
        doc = fitz.open(input_file.name)
        selected_pages = parse_page_ranges(page_selection_str)
        # Blank selection means "all pages"; user-entered pages are 1-based.
        page_indices = [p - 1 for p in selected_pages] if selected_pages else range(len(doc))
        for i, page_num in enumerate(page_indices):
            if not (0 <= page_num < len(doc)): continue
            page = doc.load_page(page_num)
            processed_successfully = False
            # Render at 300 DPI first; if the page is too dense for the
            # model's context window, retry once at 150 DPI.
            for dpi in [300, 150]:
                progress((i + 0.5) / len(page_indices), desc=f"Page {page_num + 1} at {dpi} DPI")
                try:
                    pix = page.get_pixmap(dpi=dpi)
                    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    resized_img = resize_image_if_needed(img)
                    with io.BytesIO() as buf:
                        resized_img.save(buf, format='PNG')
                        api_response = call_parse_api_base64(buf.getvalue())
                    page_data_for_processing.append((resized_img, api_response))
                    processed_successfully = True
                    break
                except gr.Error as e:
                    if "maximum context length" in str(e) and dpi == 300:
                        print(f"Warning: Page {page_num + 1} too dense at 300 DPI. Retrying at 150 DPI.")
                        continue
                    else: raise e
            if not processed_successfully:
                raise gr.Error(f"Failed to process page {page_num + 1} even at lower resolutions.")
    else:
        # Single image upload: treat as one page.
        img = Image.open(input_file.name).convert("RGB")
        resized_img = resize_image_if_needed(img)
        with io.BytesIO() as buf:
            resized_img.save(buf, format='PNG')
            api_response = call_parse_api_base64(buf.getvalue())
        page_data_for_processing.append((resized_img, api_response))
    if not page_data_for_processing:
        return [], [], "No pages selected or file is empty.", "", "", "", pd.DataFrame(), gr.Group(visible=False), gr.Dropdown(choices=[])
    all_processed_pages, all_gallery_images, all_question_data = [], [], []
    for resized_img, api_response in page_data_for_processing:
        boxed_img, page_gallery, page_q_data, _ = process_and_crop(resized_img, api_response, split_page_toggle)
        all_processed_pages.append(boxed_img)
        all_gallery_images.extend(page_gallery)
        all_question_data.extend(page_q_data)
    summary = f"Processed {len(page_data_for_processing)} page(s) and found {len(all_question_data)} questions."
    # Session id keys the module-level stores read by the ZIP/PDF handlers.
    session_id = str(time.time()).replace('.', '')
    PROCESSED_PAGES_STORE[session_id] = all_processed_pages
    CROPPED_QUESTIONS_STORE[session_id] = all_question_data
    pages_info = f"Available: {', '.join(str(i+1) for i in range(len(all_processed_pages)))}"
    questions_info = f"Available: {', '.join(str(item[0]) for item in all_question_data)}"
    # Seed the editable metadata table; Difficulty/Status are categorical so
    # the Gradio DataFrame renders them as dropdowns.
    report_df = pd.DataFrame({
        "Include": [True] * len(all_question_data),
        "Question Number": [item[0] for item in all_question_data],
        "Subject": ["" for _ in all_question_data],
        "Topic": ["" for _ in all_question_data],
        "Difficulty": pd.Categorical([""] * len(all_question_data), categories=["", "Easy", "Medium", "Hard"]),
        "Status": pd.Categorical([""] * len(all_question_data), categories=["", "Correct", "Wrong", "Unattempted"])
    })
    column_choices = report_df.columns.tolist()
    return (
        all_processed_pages, all_gallery_images, summary, session_id,
        pages_info, questions_info, report_df, gr.Group(visible=True),
        gr.Dropdown(choices=column_choices, interactive=True)
    )
def generate_report_pdf(session_id: str, report_df: pd.DataFrame, pdf_name: str, images_per_page: int, orientation: str, progress=gr.Progress(track_tqdm=True)):
    """Build the metadata-annotated PDF report from the session's cropped questions.

    Only rows with 'Include' checked are used; each row's metadata travels
    with its cropped image into create_a4_pdf_from_images. Raises gr.Error
    when the session is stale, nothing is selected, or no pages were made.
    """
    if session_id not in CROPPED_QUESTIONS_STORE:
        raise gr.Error("Session expired or invalid. Please re-process the files.")
    selected_rows = report_df[report_df['Include']].to_dict('records')
    if not selected_rows:
        raise gr.Error("No questions selected to include in the report.")
    # Map question number -> cropped PIL image for quick lookup.
    images_by_qnum = {q_num: image for q_num, image, _ in CROPPED_QUESTIONS_STORE[session_id]}
    image_info_for_pdf = []
    for row in selected_rows:
        q_num = row['Question Number']
        if q_num not in images_by_qnum:
            continue
        entry = row.copy()
        entry['image'] = images_by_qnum[q_num]
        image_info_for_pdf.append(entry)
    pdf_filename_base = re.sub(r'[^\w-]', '_', pdf_name) if pdf_name else "Question_Report"
    pdf_path = create_a4_pdf_from_images(image_info_for_pdf, int(images_per_page), pdf_filename_base, orientation, progress)
    if not pdf_path:
        raise gr.Error("Failed to generate PDF. No pages were created.")
    return gr.File(value=pdf_path, label="Download PDF Report")
| # --- Gradio UI Layout --- | |
if __name__ == "__main__":
    # Build and launch the Gradio UI.
    # NOTE(review): the `with` nesting below was reconstructed from a
    # flattened source dump — verify the layout against the running app.
    with gr.Blocks(title="NIM Question Extractor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📄 NVIDIA NIM Question Extractor & Report Generator")
        gr.Markdown("Extract questions, add custom metadata, and generate an optimized PDF report.")
        # Session id returned by question_extractor_app; keys the module stores.
        session_id_state = gr.State()
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## 1. Input & Options")
                    pdf_input = gr.File(label="Upload PDF File", file_types=['.pdf'])
                    image_input = gr.File(label="Upload Image File", file_types=['.png', '.jpg', '.jpeg'])
                    page_select_input = gr.Textbox(label="Select Pages (PDF only)", placeholder="e.g., 1, 3, 5-10")
                    split_toggle = gr.Checkbox(label="Two-Column Layout")
                    submit_btn = gr.Button("🚀 Start Question Extraction", variant="primary")
                with gr.Group():
                    gr.Markdown("## 2. Download Raw Images")
                    with gr.Accordion("Download ZIP Files", open=False):
                        download_pages_info = gr.Textbox(label="Available Pages", interactive=False)
                        download_pages_input = gr.Textbox(label="Select Pages to ZIP", placeholder="Leave blank for all")
                        download_pages_btn = gr.DownloadButton("📥 Pages ZIP", variant="secondary", interactive=False)
                        download_questions_info = gr.Textbox(label="Available Questions", interactive=False)
                        download_questions_input = gr.Textbox(label="Select Questions to ZIP", placeholder="Leave blank for all")
                        download_questions_btn = gr.DownloadButton("📥 Questions ZIP", variant="secondary", interactive=False)
            with gr.Column(scale=2):
                gr.Markdown("## 3. Review Extraction")
                output_summary = gr.Textbox(label="Processing Summary", interactive=False)
                with gr.Tab("Processed Pages (with boxes)"):
                    output_processed_pages = gr.Gallery(label="Pages with Boundaries", height=400, columns=2, object_fit="contain", show_label=False)
                with gr.Tab("Individual Questions"):
                    output_cropped_gallery = gr.Gallery(label="Cropped Questions", height=400, columns=4, object_fit="contain", show_label=False)
        # Hidden until a successful extraction fills the report table.
        with gr.Group(visible=False) as report_group:
            gr.Markdown("--- \n ## 4. Create PDF Report")
            gr.Markdown("Edit the table below to add metadata. Uncheck 'Include' to exclude a question from the report.")
            with gr.Accordion("Bulk Edit Tools", open=False):
                with gr.Row():
                    select_all_btn = gr.Button("Select All")
                    deselect_all_btn = gr.Button("Deselect All")
                with gr.Row():
                    column_select_dropdown = gr.Dropdown(label="Select Column", interactive=False)
                    value_to_apply_input = gr.Textbox(label="Value to Apply", placeholder="e.g., Physics")
                    apply_to_col_btn = gr.Button("Apply Value to Column")
                with gr.Row():
                    new_col_name_input = gr.Textbox(label="Custom Column Name", placeholder="e.g., Source Book")
                    add_col_btn = gr.Button("Add Column")
            report_metadata_df = gr.DataFrame(
                headers=["Include", "Question Number", "Subject", "Topic", "Difficulty", "Status"],
                datatype=["bool", "number", "str", "str", "categorical", "categorical"],
                interactive=True
            )
            with gr.Accordion("PDF Layout Options", open=True):
                with gr.Row():
                    pdf_name_input = gr.Textbox("Question_Report", label="PDF Filename", scale=2)
                    images_per_page_input = gr.Slider(1, 16, value=4, step=1, label="Images Per Page", scale=2)
                    orientation_radio = gr.Radio(["Auto", "Portrait", "Landscape"], label="Page Orientation", value="Auto", scale=1)
            generate_pdf_btn = gr.Button("📄 Generate PDF Report", variant="primary")
            pdf_output_file = gr.File(label="Download PDF Report", interactive=False)
        # --- Event Handlers ---
        def toggle_include_all(df, select_all_flag):
            # Set every row's 'Include' checkbox to select_all_flag at once.
            if not df.empty:
                df['Include'] = select_all_flag
            return df
        def apply_value_to_column(df, col_name, value):
            # Bulk-fill one column of the metadata table with a single value.
            if col_name and col_name in df.columns and value is not None:
                df[col_name] = value
            return df
        select_all_btn.click(
            fn=lambda df: toggle_include_all(df, True),
            inputs=[report_metadata_df],
            outputs=[report_metadata_df]
        )
        deselect_all_btn.click(
            fn=lambda df: toggle_include_all(df, False),
            inputs=[report_metadata_df],
            outputs=[report_metadata_df]
        )
        apply_to_col_btn.click(
            fn=apply_value_to_column,
            inputs=[report_metadata_df, column_select_dropdown, value_to_apply_input],
            outputs=[report_metadata_df]
        )
        def add_custom_column(df, col_name):
            # Add a user-defined empty column (if new and the table is non-empty).
            if col_name and col_name not in df.columns and not df.empty:
                df[col_name] = ""
            # Return updated dataframe and update the choices for the dropdown
            return df, gr.Dropdown(choices=df.columns.tolist(), interactive=True)
        add_col_btn.click(
            fn=add_custom_column,
            inputs=[report_metadata_df, new_col_name_input],
            outputs=[report_metadata_df, column_select_dropdown]
        )
        submit_btn.click(
            fn=question_extractor_app,
            inputs=[pdf_input, image_input, split_toggle, page_select_input],
            outputs=[output_processed_pages, output_cropped_gallery, output_summary, session_id_state,
                     download_pages_info, download_questions_info, report_metadata_df, report_group, column_select_dropdown]
        ).then(
            # Enable the ZIP download buttons only after a successful run.
            lambda: (gr.DownloadButton(interactive=True), gr.DownloadButton(interactive=True)),
            outputs=[download_pages_btn, download_questions_btn]
        )
        download_pages_btn.click(
            fn=zip_selected_pages, inputs=[download_pages_input, session_id_state], outputs=[download_pages_btn]
        )
        download_questions_btn.click(
            fn=zip_selected_questions, inputs=[download_questions_input, session_id_state], outputs=[download_questions_btn]
        )
        generate_pdf_btn.click(
            fn=generate_report_pdf,
            inputs=[session_id_state, report_metadata_df, pdf_name_input, images_per_page_input, orientation_radio],
            outputs=[pdf_output_file]
        )
    demo.launch(debug=True)