Spaces:
Paused
Paused
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import shutil | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor | |
| from io import BytesIO | |
| from typing import Optional | |
| from urllib.parse import urlparse | |
| import layoutparser as lp | |
| import openai | |
| import pytesseract | |
| import requests | |
| from dotenv import load_dotenv | |
| from pdf2image import convert_from_bytes | |
| from pydantic import BaseModel, ConfigDict | |
| from create_assistant import create_assistant | |
# Pull settings from a local .env file into the process environment before
# anything further down reads os.environ.
load_dotenv()

# Emit INFO-and-above records through a plain stream handler.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
class Block(BaseModel):
    """A detected layout element together with the page it was found on."""

    # layoutparser elements are not pydantic-native types, so arbitrary types
    # have to be allowed for the `block` field to validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The layout element produced by layout detection.
    block: lp.elements.base.BaseLayoutElement
    # Zero-based position of the page within the list of page layouts.
    page_index: int
class CaptionedBlock(Block):
    """A figure/table block paired with the text element picked as its caption."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Text layout element selected as the caption for `block`.
    caption: lp.elements.base.BaseLayoutElement
def _overlap_ratio(first, second):
    """Return the intersection area of two elements relative to the smaller one.

    Normalising by the smaller area makes the ratio symmetric: a small element
    fully contained in a large one scores 1.0 regardless of which is `first`.
    Degenerate (zero-area) elements never count as overlapping.
    """
    smaller_area = min(first.area, second.area)
    if smaller_area <= 0:
        return 0.0
    # NOTE(review): relies on intersect() yielding a non-positive area for
    # disjoint elements - confirm against the layoutparser implementation.
    return first.intersect(second).area / smaller_area


def get_blocks_and_texts(layouts: list[lp.Layout]) -> tuple[list[Block], list[Block]]:
    """Split detected layout elements into figure/table blocks and text blocks.

    "Table" and "Figure" elements are de-duplicated per page: when two of them
    overlap by more than 50% of the smaller element's area, only the larger
    one (the "superset") is kept. On equal areas the earlier element wins.
    "Text" elements are collected unchanged; every other type is ignored.

    Args:
        layouts: One layout per page, in page order.

    Returns:
        A ``(blocks, texts)`` tuple of ``Block`` lists, each entry tagged with
        the zero-based index of the page it came from.
    """
    blocks = []
    texts = []
    for page_index, layout in enumerate(layouts):
        for block in layout:
            if block.type == "Text":
                texts.append(Block(block=block, page_index=page_index))
                continue
            if block.type not in ("Table", "Figure"):
                continue

            # All already-kept blocks on this page that overlap significantly
            # with the current candidate.
            overlapping = [
                existing
                for existing in blocks
                if existing.page_index == page_index
                and _overlap_ratio(existing.block, block) > 0.5
            ]

            # An existing block that is at least as large already covers the
            # candidate, so the candidate is dropped.
            if any(existing.block.area >= block.area for existing in overlapping):
                continue

            # Otherwise the candidate is the superset: drop every smaller block
            # it overlaps (by identity, not equality) and keep the candidate.
            if overlapping:
                overlapped_ids = {id(existing) for existing in overlapping}
                blocks = [kept for kept in blocks if id(kept) not in overlapped_ids]
            blocks.append(Block(block=block, page_index=page_index))
    return blocks, texts
def caption_blocks(blocks: list[Block], texts: list[Block]) -> list[CaptionedBlock]:
    """Attach the nearest same-page text element to each figure/table block.

    Two distances are measured for every candidate text on the same page:
    from the block's bottom-centre to the text's top-centre (text below the
    block) and from the block's top-centre to the text's bottom-centre (text
    above the block). The "below" distance is scaled by 0.75 so captions
    underneath a block are preferred. Blocks with no text on their page are
    omitted from the result.

    Args:
        blocks: Figure/table blocks to caption.
        texts: Candidate text blocks.

    Returns:
        One ``CaptionedBlock`` per block for which a caption was found.
    """
    captioned_blocks = []
    for figure in blocks:
        figure_box = figure.block.block
        figure_center_x = (figure_box.x_1 + figure_box.x_2) / 2

        best_text = None
        best_distance = float("inf")
        for candidate in texts:
            if candidate.page_index == figure.page_index:
                text_box = candidate.block.block
                text_center_x = (text_box.x_1 + text_box.x_2) / 2
                dx_squared = (figure_center_x - text_center_x) ** 2

                # Text sitting below the block: block bottom -> text top.
                below = (dx_squared + (figure_box.y_2 - text_box.y_1) ** 2) ** 0.5
                # Text sitting above the block: block top -> text bottom.
                above = (dx_squared + (figure_box.y_1 - text_box.y_2) ** 2) ** 0.5

                # Shrink the "below" distance by 25% to favour bottom captions.
                distance = min(below * 0.75, above)
                if distance < best_distance:
                    best_distance, best_text = distance, candidate

        if best_text is None:
            continue
        captioned_blocks.append(
            CaptionedBlock(
                block=figure.block,
                caption=best_text.block,
                page_index=figure.page_index,
            )
        )
    return captioned_blocks
def combine_blocks(captioned_block, pages):
    """Crop the page region that spans both a block and its caption.

    Args:
        captioned_block: Object exposing ``block``, ``caption`` and
            ``page_index``; the coordinates are read from ``.block.block`` and
            ``.caption.block``.
        pages: Sequence of page images, indexed by ``page_index``.

    Returns:
        Whatever ``crop`` returns for the smallest box enclosing both the
        block and its caption.
    """
    figure_box = captioned_block.block.block
    caption_box = captioned_block.caption.block

    # Smallest axis-aligned box containing both rectangles.
    left = min(figure_box.x_1, caption_box.x_1)
    top = min(figure_box.y_1, caption_box.y_1)
    right = max(figure_box.x_2, caption_box.x_2)
    bottom = max(figure_box.y_2, caption_box.y_2)

    page = pages[captioned_block.page_index]
    return page.crop((left, top, right, bottom))
def process_captioned_block(captioned_block, pages, base_path):
    """Save a block-plus-caption crop as a JPEG and OCR its caption.

    The image is written to ``<base_path>/figures/<name>.jpg`` where ``<name>``
    is derived from the OCR'd caption text (non-alphanumeric runs replaced by
    ``_``, truncated to 30 characters). When the caption yields no usable
    characters, a name built from the page index and the block position is
    used instead, so that unrelated figures do not all end up as ``.jpg`` /
    ``_.jpg`` and overwrite one another.

    Args:
        captioned_block: Block/caption pair to export.
        pages: Page images, indexed by ``captioned_block.page_index``.
        base_path: Output directory; a ``figures`` sub-directory is created.

    Returns:
        ``{"image": "figures/<name>.jpg", "caption": <ocr text>}`` with the
        image path relative to ``base_path``.
    """
    combined_image = combine_blocks(captioned_block, pages)

    # OCR only the caption region to obtain the caption text.
    caption_box = captioned_block.caption.block
    caption_image = pages[captioned_block.page_index].crop(
        (caption_box.x_1, caption_box.y_1, caption_box.x_2, caption_box.y_2)
    )
    caption_text = pytesseract.image_to_string(caption_image)

    figures_path = os.path.join(base_path, "figures")
    os.makedirs(figures_path, exist_ok=True)

    # Caption text -> alphanumeric snake-case file stem, truncated.
    stem = re.sub("[^0-9a-zA-Z]+", "_", caption_text)[:30]
    if not stem.strip("_"):
        # Blank or unreadable caption: fall back to a position-based name.
        figure_box = captioned_block.block.block
        stem = (
            f"page_{captioned_block.page_index}"
            f"_{int(figure_box.x_1)}_{int(figure_box.y_1)}"
        )
    img_name = stem + ".jpg"
    img_path = os.path.join(figures_path, img_name)

    # Encode the combined crop as JPEG straight to disk.
    combined_image.save(img_path, format="JPEG")

    return {"image": f"figures/{img_name}", "caption": caption_text}
def process_pdf(content: bytes, model: lp.models.Detectron2LayoutModel, base_path: str):
    """Extract captioned figures and tables from a PDF.

    The PDF bytes are rendered to page images, ``model.detect`` is run on
    every page, figure/table blocks are paired with captions, and each pair is
    exported via ``process_captioned_block``. Detection and export each run on
    a pool of up to 16 worker threads.

    Args:
        content: Raw bytes of the PDF document.
        model: Layout model whose ``detect`` method is applied to each page.
        base_path: Directory under which the figure images are written.

    Returns:
        The list of dicts returned by ``process_captioned_block``, one per
        captioned block, in the order of the captioned blocks.
    """
    pages = convert_from_bytes(content)
    logger.info("PDF converted to images")

    with ThreadPoolExecutor(max_workers=16) as pool:
        layouts = [layout for layout in pool.map(model.detect, pages)]
    logger.info("Layout detection completed")

    blocks, texts = get_blocks_and_texts(layouts)
    logger.info("Blocks and texts extracted")

    captioned_blocks = caption_blocks(blocks, texts)
    logger.info("Captioning completed")

    def export(captioned_block):
        # Bind the shared page images and output directory for the pool.
        return process_captioned_block(captioned_block, pages, base_path)

    with ThreadPoolExecutor(max_workers=16) as pool:
        results = list(pool.map(export, captioned_blocks))
    return results
def wait_on_run(run, thread, client: openai.OpenAI):
    """Poll a run until it is no longer queued or in progress.

    The pause happens *before* each poll, so the function returns as soon as
    a poll reports a finished run instead of sleeping once more afterwards.

    Args:
        run: Run object exposing ``id`` and ``status``.
        thread: Thread object exposing ``id``.
        client: Client used to re-fetch the run.

    Returns:
        The most recently retrieved run, whose status is neither ``"queued"``
        nor ``"in_progress"``. Callers must still inspect the status, since it
        is not necessarily a successful completion.
    """
    while run.status in ("queued", "in_progress"):
        time.sleep(0.5)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
    return run
def _extract_json_payload(message_content: str) -> str:
    """Return the JSON text embedded in an assistant reply.

    Prefers a ```json fenced block, then any fenced block, and finally falls
    back to the whole (stripped) message so that a reply consisting of bare
    JSON is accepted as well.
    """
    for pattern in (r"(```json\n)(.*?)(\n```)", r"(```\n)(.*?)(\n```)"):
        match = re.search(pattern, message_content, re.DOTALL)
        if match is not None:
            return match.group(2)
    return message_content.strip()


def generate_thread_content(
    pdf_path: str, results: dict, client: openai.OpenAI, assistant_id: str
):
    """Ask the assistant to write a thread about a PDF and parse its JSON reply.

    The PDF is uploaded, a conversation thread is created with the extracted
    figure data plus the uploaded file, the assistant run is awaited, and the
    first new message is parsed as JSON. The uploaded file is always deleted
    afterwards; a failure to delete is logged but not raised.

    Args:
        pdf_path: Path of the PDF to upload.
        results: JSON-serialisable figure/caption data embedded in the prompt
            (``create_thread`` passes a list of dicts).
        client: OpenAI client.
        assistant_id: Id of the assistant that executes the run.

    Returns:
        The decoded JSON value produced by the assistant.

    Raises:
        ValueError: If no reply message is available or the reply does not
            contain valid JSON.
    """
    with open(pdf_path, "rb") as f:
        pdf_file = client.files.create(file=f, purpose="assistants")

    try:
        thread = client.beta.threads.create()
        message = client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=f"{json.dumps(results)}\n\nCreate a thread for this. Your answer must be in JSON, media links should be from the local paths above.",
            file_ids=[pdf_file.id],
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id, assistant_id=assistant_id
        )
        run = wait_on_run(run, thread, client)
        messages = client.beta.threads.messages.list(
            thread_id=thread.id, order="asc", after=message.id
        )
        # A run can finish without any new message being listed in the
        # thread; surface that as an explicit error.
        if not messages.data or not messages.data[0].content:
            raise ValueError("Unexpected empty response from OpenAI. Please try again.")
    except Exception as e:
        logger.error(f"Failed to generate thread content: {e}")
        raise
    finally:
        # Best-effort removal of the uploaded PDF.
        try:
            client.files.delete(file_id=pdf_file.id)
        except Exception as e:
            logger.error(f"Failed to delete file: {e}")

    # Extract and decode the JSON content of the first reply message.
    message_content = messages.data[0].content[0].text.value
    try:
        paper_thread = json.loads(_extract_json_payload(message_content))
    except (json.JSONDecodeError, TypeError) as err:
        raise ValueError(
            "The thread generated by OpenAI was not in the expected JSON format."
        ) from err
    return paper_thread
def process_thread(thread_data, base_path):
    """Clean the generated thread and keep only usable media references.

    For every entry the ``【N†source】`` annotations are stripped from the
    content. A media item is kept only if it has a non-empty ``"path"``, that
    path has not been used by an earlier kept item, and the file exists under
    ``base_path``. A missing or ``null`` ``"media"`` value and media items
    without a ``"path"`` key are tolerated and treated as "no media".

    Args:
        thread_data: Iterable of dicts with a ``"content"`` string and an
            optional ``"media"`` list of dicts.
        base_path: Directory that media paths are resolved against.

    Returns:
        A list of ``{"content": str, "media": list}`` dicts, one per entry.
    """
    processed_data = []
    media_set = set()
    for data in thread_data:
        # Remove all source annotations.
        cleaned_content = re.sub(r"【\d+†source】", "", data["content"])

        media_list = []
        # `or []` also covers an explicit `"media": null` in the JSON.
        for media in data.get("media") or []:
            media_path = media.get("path")
            if not media_path or media_path in media_set:
                continue
            if os.path.isfile(os.path.join(base_path, media_path)):
                media_list.append(media)
                media_set.add(media_path)

        processed_data.append({"content": cleaned_content, "media": media_list})
    return processed_data
def render_markdown(processed_thread):
    """Render a processed thread as a markdown document.

    Each entry contributes its content, one centred ``<img>`` per media item,
    and a ``---`` separator. The image path and the alt text (taken from the
    media item's ``"explain"`` value, empty when missing or null) are
    HTML-escaped so quotes or angle brackets in them cannot break the tag.

    Args:
        processed_thread: Iterable of ``{"content": str, "media": list}``
            dicts; every media item needs a ``"path"`` key.

    Returns:
        The markdown text as a single string.
    """
    # Local import: the html module is not imported at file level.
    from html import escape

    parts = []
    for data in processed_thread:
        parts.append(data["content"] + "\n")
        for media in data["media"]:
            src = escape(str(media["path"]), quote=True)
            alt = escape(str(media.get("explain") or ""), quote=True)
            parts.append('\n<div align="center">\n')
            parts.append(f' <img src="{src}" alt="{alt}" style="max-width: 75%;">\n')
            parts.append("</div>\n")
        parts.append("\n---\n\n")
    return "".join(parts)
def uri_validator(x):
    """Return True if ``x`` parses as a URI with both a scheme and a netloc.

    Plain file names and local paths have no network location and therefore
    yield False. Input that ``urlparse`` rejects also yields False; only the
    parsing errors are caught, so unrelated exceptions such as
    KeyboardInterrupt are no longer swallowed.
    """
    try:
        result = urlparse(x)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed URL (e.g. unbalanced IPv6 brackets);
        # AttributeError/TypeError: input that is not str/bytes-like.
        return False
    return bool(result.scheme and result.netloc)
def create_thread(
    pdf_url_or_path: str, output_path: str, client: openai.OpenAI, assistant_id: str
):
    """Run the full PDF -> figures -> generated thread pipeline.

    All outputs go to ``<output_path>/<pdf_name>/``: the PDF copy,
    ``results.json`` (extracted figures and captions), ``thread.json`` (raw
    assistant output), ``processed_thread.json`` and ``processed_thread.md``.
    If ``results.json`` already exists there, download and layout detection
    are skipped and the cached results are reused.

    ``pdf_name`` is the last path component of the input with a trailing
    ``.pdf`` removed. Only ``.pdf`` is stripped, so dotted identifiers such as
    ``1706.03762`` stay intact instead of being cut at the last dot, which
    would make different documents share one output folder and its cache.

    Args:
        pdf_url_or_path: URL or local path of the PDF.
        output_path: Root directory for the outputs.
        client: OpenAI client.
        assistant_id: Id of the assistant that writes the thread.

    Returns:
        The directory containing all outputs for this PDF.

    Raises:
        ValueError: If the input is neither a URL nor an existing file, or no
            name can be derived from it.
        requests.HTTPError: If downloading the PDF returns an error status.
    """
    is_url = uri_validator(pdf_url_or_path)

    # Last path component; for URLs the query string and fragment are ignored.
    if is_url:
        pdf_name = urlparse(pdf_url_or_path).path.rstrip("/").split("/")[-1]
    else:
        pdf_name = os.path.basename(pdf_url_or_path)
    if pdf_name.lower().endswith(".pdf"):
        pdf_name = pdf_name[: -len(".pdf")]
    if not pdf_name:
        raise ValueError(f"Could not derive a PDF name from: {pdf_url_or_path}")

    base_path = os.path.join(output_path, pdf_name)
    results_path = os.path.join(base_path, "results.json")
    pdf_path = os.path.join(base_path, f"{pdf_name}.pdf")
    thread_path = os.path.join(base_path, "thread.json")
    processed_thread_path = os.path.join(base_path, "processed_thread.json")
    markdown_path = os.path.join(base_path, "processed_thread.md")

    # An existing results.json means this PDF was processed before: reuse it.
    if os.path.isfile(results_path):
        with open(results_path, "r", encoding="utf-8") as f:
            results = json.load(f)
    else:
        # Validate the input before creating any directory on disk.
        if not is_url and not os.path.isfile(pdf_url_or_path):
            raise ValueError(
                f"Invalid input: {pdf_url_or_path}. It should be a valid URL or a file path."
            )
        os.makedirs(base_path, exist_ok=True)

        if is_url:
            # The timeout keeps a stalled server from hanging the run, and the
            # status check prevents an HTTP error page being saved as a PDF.
            response = requests.get(pdf_url_or_path, timeout=60)
            response.raise_for_status()
            pdf_content = response.content
            with open(pdf_path, "wb") as f:
                f.write(pdf_content)
        else:
            shutil.copy(pdf_url_or_path, pdf_path)
            with open(pdf_path, "rb") as f:
                pdf_content = f.read()

        model = lp.models.Detectron2LayoutModel(
            config_path="lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
            extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5],
            label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
        )
        results = process_pdf(pdf_content, model, base_path)

        # Remove duplicate entries while keeping first-seen order, so the
        # saved results are deterministic.
        unique_items = dict.fromkeys(tuple(d.items()) for d in results)
        results = [dict(items) for items in unique_items]
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    paper_thread = generate_thread_content(pdf_path, results, client, assistant_id)
    with open(thread_path, "w", encoding="utf-8") as f:
        json.dump(paper_thread, f, indent=2)

    # Clean the thread and drop media references that do not exist on disk.
    processed_thread = process_thread(paper_thread, base_path)
    with open(processed_thread_path, "w", encoding="utf-8") as f:
        json.dump(processed_thread, f, indent=2)

    # Save the processed thread as a markdown file; UTF-8 because the
    # generated text may contain non-ASCII characters.
    markdown_content = render_markdown(processed_thread)
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write(markdown_content)

    logger.info(f"Saved all outputs to: {os.path.abspath(base_path)}")
    return base_path
def create_assistant_then_thread(
    pdf_url_or_path: str,
    output_path: str,
    client: openai.OpenAI,
    assistant_kwargs: Optional[dict] = None,
):
    """Create a temporary assistant, generate the thread, then delete it.

    The assistant is deleted whether or not thread creation succeeds. A
    failed deletion is logged but not raised: raising from the cleanup step
    would replace the real error from thread creation, or throw away an
    already completed result. This matches how the uploaded PDF is cleaned up
    in ``generate_thread_content``.

    Args:
        pdf_url_or_path: URL or local path of the PDF.
        output_path: Root directory for the outputs.
        client: OpenAI client.
        assistant_kwargs: Extra keyword arguments for ``create_assistant``.

    Returns:
        The output directory returned by ``create_thread``.

    Raises:
        Exception: Whatever ``create_assistant`` or ``create_thread`` raises,
            re-raised after being logged.
    """
    if assistant_kwargs is None:
        assistant_kwargs = {}

    try:
        assistant = create_assistant(client, **assistant_kwargs)
    except Exception:
        logger.error("Failed to create assistant", exc_info=True)
        raise

    try:
        saved_path = create_thread(
            pdf_url_or_path,
            output_path,
            client,
            assistant.id,
        )
    except Exception:
        logger.error("Failed to create thread", exc_info=True)
        raise
    finally:
        # Best-effort cleanup; never let it mask the outcome above.
        try:
            client.beta.assistants.delete(assistant.id)
        except Exception:
            logger.error("Failed to delete assistant", exc_info=True)
    return saved_path
if __name__ == "__main__":
    # Command line: one positional PDF location plus an optional output dir.
    parser = argparse.ArgumentParser(
        description="Process a PDF from a URL or a local path."
    )
    parser.add_argument(
        "url_or_path",
        type=str,
        help="The URL or local path of the PDF to process.",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="data",
        help="The output directory to store the results.",
    )
    args = parser.parse_args()

    # The API key must be present in the environment (KeyError otherwise).
    api_key = os.environ["OPENAI_API_KEY"]
    client = openai.OpenAI(api_key=api_key)
    create_assistant_then_thread(args.url_or_path, args.output, client)