Spaces:
Running
Running
| """ | |
| pdf text and asset extraction | |
| """ | |
| import json | |
| import random | |
| import re | |
| from pathlib import Path | |
| from typing import Dict, Any, Tuple | |
| from marker.converters.pdf import PdfConverter | |
| from marker.renderers.markdown import MarkdownRenderer | |
| from marker.models import create_model_dict | |
| from marker.output import text_from_rendered | |
| from marker.schema import BlockTypes | |
| from jinja2 import Template | |
| from src.state.poster_state import PosterState | |
| from utils.langgraph_utils import LangGraphAgent, extract_json, load_prompt | |
| from utils.src.logging_utils import log_agent_info, log_agent_success, log_agent_error, log_agent_warning | |
| from src.config.poster_config import load_config | |
class Parser:
    """Parsing agent: extracts text, figures, tables, and narrative content from a paper PDF."""

    def __init__(self):
        self.name = "parser"
        cfg = load_config()
        batches = cfg["pdf_processing"]["batch_sizes"]
        marker_config = {
            "recognition_batch_size": batches["recognition"],
            "layout_batch_size": batches["layout"],
            "detection_batch_size": batches["detection"],
            "table_rec_batch_size": batches["table_rec"],
            "ocr_error_batch_size": batches["ocr_error"],
            "equation_batch_size": batches["equation"],
            "disable_tqdm": False,
        }
        self.converter = PdfConverter(artifact_dict=create_model_dict(), config=marker_config)
        # strips html comments that the markdown renderer leaves behind
        self.clean_pattern = re.compile(r"<!--[\s\S]*?-->")
        # prompt templates for the llm-backed extraction steps
        self.enhanced_abt_prompt = load_prompt("config/prompts/narrative_abt_extraction.txt")
        self.visual_classification_prompt = load_prompt("config/prompts/classify_visuals.txt")
        self.title_authors_prompt = load_prompt("config/prompts/extract_title_authors.txt")
        self.section_extraction_prompt = load_prompt("config/prompts/extract_structured_sections.txt")
def __call__(self, state: PosterState) -> PosterState:
    """Run the full parsing pipeline over the paper referenced by the state.

    Extracts raw markdown text and image/table assets, asks the text model
    for title/authors, an ABT narrative, a visual-importance classification,
    and structured sections, persists all artifacts under the output dir,
    and records everything back onto the state. Errors are appended to
    state["errors"] rather than raised, and the state is always returned.
    """
    log_agent_info(self.name, "starting foundation building")
    try:
        out_root = Path(state["output_dir"])
        content_dir = out_root / "content"
        assets_dir = out_root / "assets"
        for directory in (content_dir, assets_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # raw text plus figure/table assets from the pdf
        raw_text, raw_result = self._extract_raw_text(state["pdf_path"], content_dir)
        figures, tables = self._extract_assets(raw_result, state["poster_name"], assets_dir)

        # paper metadata via the llm
        title, authors = self._extract_title_authors(raw_text, state["text_model"])

        # abt narrative generation (token usage is tracked)
        narrative_content, ti, to = self._generate_narrative_content(raw_text, state["text_model"])
        state["tokens"].add_text(ti, to)

        # importance classification of the visuals (token usage is tracked)
        classified_visuals, ti, to = self._classify_visual_assets(figures, tables, raw_text, state["text_model"])
        state["tokens"].add_text(ti, to)

        narrative_content["meta"] = {"poster_title": title, "authors": authors}

        # structured section extraction
        structured_sections = self._extract_structured_sections(raw_text, state["text_model"])

        # persist artifacts, then mirror everything onto the state
        self._save_content(narrative_content, "narrative_content.json", content_dir)
        self._save_content(classified_visuals, "classified_visuals.json", content_dir)
        self._save_content(structured_sections, "structured_sections.json", content_dir)
        self._save_raw_text(raw_text, content_dir)

        state["raw_text"] = raw_text
        state["structured_sections"] = structured_sections
        state["narrative_content"] = narrative_content
        state["classified_visuals"] = classified_visuals
        state["images"] = figures
        state["tables"] = tables
        state["current_agent"] = self.name

        log_agent_success(self.name, f"extracted raw text, {len(figures)} images, and {len(tables)} tables")
        log_agent_success(self.name, f"extracted title: {title}")
        log_agent_success(self.name, "generated enhanced abt narrative")
        log_agent_success(self.name, f"classified visuals: key={classified_visuals.get('key_visual', 'none')}, problem_ill={len(classified_visuals.get('problem_illustration', []))}, method_wf={len(classified_visuals.get('method_workflow', []))}, main_res={len(classified_visuals.get('main_results', []))}, comp_res={len(classified_visuals.get('comparative_results', []))}, support={len(classified_visuals.get('supporting', []))}")
    except Exception as e:
        log_agent_error(self.name, f"failed: {e}")
        state["errors"].append(str(e))
    return state
def _extract_raw_text(self, pdf_path: str, content_dir: Path) -> Tuple[str, Any]:
    """Convert the pdf to cleaned markdown text.

    Returns (text, raw_result) where raw_result is the tuple
    (document, rendered, images) consumed later by _extract_assets.
    Also writes the cleaned text to content_dir/raw.md.
    """
    log_agent_info(self.name, "converting pdf to raw text")
    document = self.converter.build_document(pdf_path)
    # render the already-built document instead of re-converting the pdf
    markdown_renderer = self.converter.resolve_dependencies(MarkdownRenderer)
    rendered = markdown_renderer(document)
    text, _, images = text_from_rendered(rendered)
    # drop html comments the renderer embeds in the markdown
    text = self.clean_pattern.sub("", text)
    (content_dir / "raw.md").write_text(text, encoding="utf-8")
    log_agent_info(self.name, f"extracted {len(text)} chars")
    return text, (document, rendered, images)
def _generate_narrative_content(self, text: str, config) -> Tuple[Dict, int, int]:
    """Generate the and/but/therefore (ABT) narrative dict via the LLM.

    Returns (narrative, input_tokens, output_tokens). Retries up to three
    times; re-raises the last exception, or raises ValueError when every
    attempt parsed but lacked the three ABT beats.
    """
    log_agent_info(self.name, "generating abt narrative")
    agent = LangGraphAgent("expert poster design consultant", config)
    for attempt in range(3):
        try:
            rendered_prompt = Template(self.enhanced_abt_prompt).render(markdown_document=text)
            agent.reset()
            reply = agent.step(rendered_prompt)
            narrative = extract_json(reply.content)
            # a usable narrative must contain all three abt beats
            if all(beat in narrative for beat in ("and", "but", "therefore")):
                return narrative, reply.input_tokens, reply.output_tokens
        except Exception as e:
            log_agent_warning(self.name, f"attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                raise
    raise ValueError("failed to generate enhanced narrative after 3 attempts")
| def _save_content(self, content: Dict, filename: str, content_dir: Path): | |
| with open(content_dir / filename, 'w', encoding='utf-8') as f: | |
| json.dump(content, f, indent=2) | |
| def _save_raw_text(self, raw_text: str, content_dir: Path): | |
| with open(content_dir / "raw.md", 'w', encoding='utf-8') as f: | |
| f.write(raw_text) | |
def _extract_assets(self, result, name: str, assets_dir: Path) -> Tuple[Dict, Dict]:
    """Save rendered images to assets_dir and split them into figures vs tables.

    Returns (figures, tables): dicts keyed by 1-based string index, each entry
    holding caption, saved path, pixel dimensions, and aspect ratio. Also dumps
    figures.json, tables.json, and the caption mapping alongside the images.
    """
    log_agent_info(self.name, "extracting assets")
    document, rendered, marker_images = result
    caption_map = self._extract_captions(document)
    figures: Dict[str, Dict] = {}
    tables: Dict[str, Dict] = {}
    image_count = 0
    table_count = 0
    for img_name, pil_image in marker_images.items():
        caption_info = caption_map.get(img_name, {'captions': [], 'block_type': 'Unknown'})

        def make_entry(default_caption: str, path: Path) -> Dict:
            # shared record builder for both figures and tables
            return {
                'caption': caption_info['captions'][0] if caption_info['captions'] else default_caption,
                'path': str(path),
                'width': pil_image.width,
                'height': pil_image.height,
                'aspect': pil_image.width / pil_image.height if pil_image.height > 0 else 1,
            }

        # note: 'table' in img_name.lower() already covers the former
        # 'Table' in img_name check, so that redundant test was dropped
        if 'table' in img_name.lower() or caption_info.get('block_type') == 'Table':
            table_count += 1
            path = assets_dir / f"table-{table_count}.png"
            pil_image.save(path, "PNG")
            tables[str(table_count)] = make_entry(f"Table {table_count}", path)
        else:
            image_count += 1
            path = assets_dir / f"figure-{image_count}.png"
            pil_image.save(path, "PNG")
            figures[str(image_count)] = make_entry(f"Figure {image_count}", path)
    with open(assets_dir / "figures.json", 'w', encoding='utf-8') as f:
        json.dump(figures, f, indent=2)
    with open(assets_dir / "tables.json", 'w', encoding='utf-8') as f:
        json.dump(tables, f, indent=2)
    with open(assets_dir / "fig_tab_caption_mapping.json", 'w', encoding='utf-8') as f:
        json.dump(caption_map, f, indent=2, ensure_ascii=False)
    return figures, tables
def _extract_captions(self, document):
    """Map rendered image filenames to caption text and block metadata.

    Walks every page's block structure: for figure/table/picture groups it
    pairs the contained visual with its caption/footnote children; for bare
    figure/table/picture blocks it falls back to _find_nearby_captions.
    Keys are "<block-path>.jpeg" filenames as emitted by the renderer.
    """
    group_types = [BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup]
    leaf_types = [BlockTypes.Figure, BlockTypes.Table, BlockTypes.Picture]
    caption_types = [BlockTypes.Caption, BlockTypes.Footnote]
    caption_map = {}
    for page in document.pages:
        for block_id in page.structure:
            block = page.get_block(block_id)
            if block.block_type in group_types:
                target = None
                captions = []
                for child_id in block.structure_blocks(page):
                    child = page.get_block(child_id)
                    if child.block_type in leaf_types:
                        target = child
                    elif child.block_type in caption_types:
                        captions.append(child.raw_text(document))
                if target:
                    filename = f"{target.id.to_path()}.jpeg"
                    caption_map[filename] = {
                        'block_id': str(target.id),
                        'block_type': str(target.block_type),
                        'captions': captions,
                        'page': page.page_id,
                    }
            elif block.block_type in leaf_types:
                filename = f"{block.id.to_path()}.jpeg"
                # groups take precedence; only fill in if not already mapped
                if filename not in caption_map:
                    caption_map[filename] = {
                        'block_id': str(block.id),
                        'block_type': str(block.block_type),
                        'captions': self._find_nearby_captions(page, block, document),
                        'page': page.page_id,
                    }
    return caption_map
def _find_nearby_captions(self, page, target_block, document):
    """Collect caption-like text for a visual that has no caption group.

    NOTE(review): despite the name, the first pass scans every caption/text
    block on the whole page, not just neighbors of target_block, so an
    unrelated caption elsewhere on the page can be picked up — confirm
    this is intended.
    """
    keywords = ('Figure', 'Table', 'Fig.')
    text_types = [BlockTypes.Caption, BlockTypes.Text]
    found = []
    for block_id in page.structure:
        candidate = page.get_block(block_id)
        if candidate.block_type in text_types:
            snippet = candidate.raw_text(document)
            if any(k in snippet for k in keywords):
                found.append(snippet)
    if found:
        return found
    # nothing on the page matched: look at the immediate prev/next blocks
    for neighbor in (page.get_prev_block(target_block), page.get_next_block(target_block)):
        if neighbor and neighbor.block_type in text_types:
            snippet = neighbor.raw_text(document)
            if any(k in snippet for k in keywords):
                found.append(snippet)
    return found
| def _cleanup_unused_assets(self, output_dir: Path, name: str, images: Dict, tables: Dict): | |
| valid_paths = set() | |
| for img_data in images.values(): | |
| valid_paths.add(Path(img_data['path']).name) | |
| for table_data in tables.values(): | |
| valid_paths.add(Path(table_data['path']).name) | |
| for png_file in output_dir.glob(f"{name}-*.png"): | |
| if png_file.name not in valid_paths: | |
| png_file.unlink() | |
def _extract_title_authors(self, text: str, config) -> Tuple[str, str]:
    """Extract (title, authors) from the paper text via the LLM.

    Retries up to three times; on total failure returns the fallback
    ("Untitled", "Authors not found") instead of raising.
    """
    log_agent_info(self.name, "extracting title and authors with llm")
    agent = LangGraphAgent("expert academic paper parser", config)
    fallback = ("Untitled", "Authors not found")
    for attempt in range(3):
        try:
            rendered_prompt = Template(self.title_authors_prompt).render(markdown_document=text)
            agent.reset()
            reply = agent.step(rendered_prompt)
            parsed = extract_json(reply.content)
            if "title" in parsed and "authors" in parsed:
                title = parsed["title"].strip()
                authors = parsed["authors"].strip()
                # both fields must be non-empty after stripping
                if title and authors:
                    return title, authors
        except Exception as e:
            log_agent_warning(self.name, f"title/authors extraction attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                return fallback
    return fallback
| def _classify_visual_assets(self, figures: Dict, tables: Dict, raw_text: str, config) -> Tuple[Dict, int, int]: | |
| # combine all visuals for classification | |
| all_visuals = [] | |
| for fig_id, fig_data in figures.items(): | |
| all_visuals.append({ | |
| "id": f"figure_{fig_id}", | |
| "type": "figure", | |
| "caption": fig_data.get("caption", ""), | |
| "aspect_ratio": fig_data.get("aspect", 1.0) | |
| }) | |
| for tab_id, tab_data in tables.items(): | |
| all_visuals.append({ | |
| "id": f"table_{tab_id}", | |
| "type": "table", | |
| "caption": tab_data.get("caption", ""), | |
| "aspect_ratio": tab_data.get("aspect", 1.0) | |
| }) | |
| if not all_visuals: | |
| return {"key_visual": None, "problem_illustration": [], "method_workflow": [], "main_results": [], "comparative_results": [], "supporting": []}, 0, 0 | |
| log_agent_info(self.name, f"classifying {len(all_visuals)} visual assets") | |
| agent = LangGraphAgent("expert poster designer", config) | |
| for attempt in range(3): | |
| try: | |
| prompt = Template(self.visual_classification_prompt).render( | |
| visuals_list=json.dumps(all_visuals, indent=2) | |
| ) | |
| agent.reset() | |
| response = agent.step(prompt) | |
| classification = extract_json(response.content) | |
| # validate classification | |
| required_keys = ["key_visual", "problem_illustration", "method_workflow", "main_results", "comparative_results", "supporting"] | |
| if all(key in classification for key in required_keys): | |
| return classification, response.input_tokens, response.output_tokens | |
| except Exception as e: | |
| log_agent_warning(self.name, f"visual classification attempt {attempt + 1} failed: {e}") | |
| if attempt == 2: | |
| # fallback classification | |
| return self._fallback_visual_classification(all_visuals), 0, 0 | |
| return self._fallback_visual_classification(all_visuals), 0, 0 | |
| def _fallback_visual_classification(self, visuals): | |
| # simple rule-based fallback | |
| classification = {"key_visual": None, "main_results": [], "method_diagrams": [], "supporting": []} | |
| for visual in visuals: | |
| caption = visual.get("caption", "").lower() | |
| if "result" in caption or "performance" in caption or "comparison" in caption: | |
| classification["main_results"].append(visual["id"]) | |
| elif "method" in caption or "architecture" in caption or "framework" in caption: | |
| classification["method_diagrams"].append(visual["id"]) | |
| else: | |
| classification["supporting"].append(visual["id"]) | |
| # select key visual from main results or method diagrams | |
| if classification["main_results"]: | |
| classification["key_visual"] = classification["main_results"][0] | |
| elif classification["method_diagrams"]: | |
| classification["key_visual"] = classification["method_diagrams"][0] | |
| return classification | |
def _extract_structured_sections(self, raw_text: str, config) -> Dict:
    """Extract structured sections from the raw paper text via the LLM.

    Retries up to three times. Raises ValueError when the final attempt
    throws; returns an empty fallback structure when every attempt parsed
    but failed validation.
    """
    log_agent_info(self.name, "extracting structured sections from paper")
    agent = LangGraphAgent("expert paper section extractor", config)
    for attempt in range(3):
        try:
            rendered_prompt = Template(self.section_extraction_prompt).render(raw_text=raw_text)
            agent.reset()
            reply = agent.step(rendered_prompt)
            candidate = extract_json(reply.content)
            if self._validate_structured_sections(candidate):
                log_agent_success(self.name, f"extracted {len(candidate.get('paper_sections', []))} structured sections")
                return candidate
            log_agent_warning(self.name, f"attempt {attempt + 1}: invalid structured sections")
        except Exception as e:
            log_agent_warning(self.name, f"section extraction attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                raise ValueError("failed to extract structured sections after multiple attempts")
    # every attempt parsed but failed validation: return an empty structure
    return {
        "paper_sections": [],
        "paper_structure": {
            "total_sections": 0,
            "foundation_sections": 0,
            "method_sections": 0,
            "evaluation_sections": 0,
            "conclusion_sections": 0,
        },
    }
| def _validate_structured_sections(self, structured_sections: Dict) -> bool: | |
| """validate structured sections format""" | |
| if "paper_sections" not in structured_sections: | |
| log_agent_warning(self.name, "validation error: missing 'paper_sections'") | |
| return False | |
| sections = structured_sections["paper_sections"] | |
| if not isinstance(sections, list) or len(sections) < 3: | |
| log_agent_warning(self.name, f"validation error: need at least 3 sections, got {len(sections)}") | |
| return False | |
| # validate each section | |
| for i, section in enumerate(sections): | |
| required_fields = ["section_name", "section_type", "content"] | |
| for field in required_fields: | |
| if field not in section: | |
| log_agent_warning(self.name, f"validation error: section {i} missing '{field}'") | |
| return False | |
| return True | |
def parser_node(state: PosterState) -> PosterState:
    """LangGraph node entry point: run a fresh Parser agent over the state."""
    agent = Parser()
    return agent(state)