"""Pipelines for vision-LLM analysis of PDF presentations (slide-by-slide)."""
import asyncio
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import fitz
from langchain.chains.base import Chain
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from tqdm import tqdm

from src.chains.chains import (
    ImageEncodeChain,
    LoadPageChain,
    Page2ImageChain,
    VisionAnalysisChain,
)
from src.chains.prompts import BasePrompt, JsonH1AndGDPrompt
from src.config.navigator import Navigator
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Shorthand alias for the structured slide-description schema defined on the prompt.
SlideDescription = JsonH1AndGDPrompt.SlideDescription
class SlideAnalysis(BaseModel):
    """Container for the analysis results of a single PDF slide.

    Attributes:
        pdf_path: Path to the source PDF file.
        page_num: Page number of the analyzed slide.
        vision_prompt: Prompt sent to the vision model; cleared via
            ``reset_vision_prompt`` after processing.
        llm_output: Raw text returned by the vision model.
        response_metadata: Provider metadata attached to the LLM response.
        parsed_output: Structured parse of ``llm_output``.
    """

    pdf_path: Path
    page_num: int
    vision_prompt: Optional[str]
    llm_output: str
    # default_factory so each instance gets its own mutable default.
    response_metadata: dict = Field(default_factory=dict)
    parsed_output: SlideDescription = Field(default_factory=SlideDescription)

    # Without this decorator the serializer is never invoked and
    # model_dump() emits a raw Path, which json.dump cannot serialize.
    @field_serializer("pdf_path")
    def serialize_path(self, pdf_path: Path) -> str:
        """Serialize the PDF path relative to the project root."""
        return str(Navigator().get_relative_path(pdf_path))

    def reset_vision_prompt(self) -> None:
        """Reset the per-slide vision prompt to None."""
        self.vision_prompt = None
class PresentationAnalysis(BaseModel):
    """Container for the analysis results of a whole presentation.

    Attributes:
        name: Presentation name (PDF file stem).
        path: Path to the source PDF file.
        vision_prompt: Prompt used for slide analysis (string or BasePrompt).
        metadata: Document metadata (page count, title, author, ...).
        slides: Per-slide analysis results.
        timestamp: ISO-format creation time of this record.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    path: Path
    vision_prompt: str
    metadata: Dict = Field(default_factory=dict)
    slides: List[SlideAnalysis] = Field(default_factory=list)
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())

    @field_serializer("vision_prompt")
    def serialize_vision_prompt(self, vision_prompt):
        """Serialize prompt objects to their text; pass plain strings through."""
        return (
            vision_prompt.prompt_text
            if isinstance(vision_prompt, BasePrompt)
            else vision_prompt
        )

    @field_serializer("path")
    def serialize_path(self, pdf_path: Path) -> str:
        """Serialize the presentation path relative to the project root."""
        return str(Navigator().get_relative_path(pdf_path))

    def save(self, save_path: Path) -> None:
        """Save analysis results to JSON at ``save_path``."""
        data = self.model_dump()
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    # @classmethod is required: callers invoke PresentationAnalysis.load(path)
    # with a single argument; without it, `path` would bind to `cls`.
    @classmethod
    def load(cls, load_path: Path) -> "PresentationAnalysis":
        """Load analysis results from a JSON file produced by ``save``."""
        with open(load_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # The stored path is relative; convert back to an absolute Path.
        data["path"] = Navigator().get_absolute_path(Path(data["path"]))
        return cls(**data)
class SingleSlidePipeline(Chain):
    """Pipeline for processing a single slide (page) from a PDF.

    Composes page loading, rasterization, image encoding, and vision-LLM
    analysis into one LangChain runnable sequence.
    """

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        vision_prompt: str = "Describe this slide in detail",
        dpi: int = 72,
        return_steps: bool = True,
        **kwargs,
    ):
        """Initialize pipeline for single slide processing.

        Args:
            llm: Language model with vision capabilities.
            vision_prompt: Prompt for slide analysis.
            dpi: Resolution for PDF rendering.
            return_steps: Whether to return intermediate chain outputs.
        """
        super().__init__(**kwargs)
        self._chain = (
            LoadPageChain()
            | Page2ImageChain(default_dpi=dpi)
            | ImageEncodeChain()
            | VisionAnalysisChain(llm=llm, prompt=vision_prompt)
        )
        self._return_steps = return_steps

    # LangChain's Chain declares input_keys/output_keys as abstract
    # properties; plain methods here would break Chain.invoke's key checks.
    @property
    def input_keys(self) -> List[str]:
        """Required input keys."""
        return ["pdf_path", "page_num"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain."""
        keys = ["slide_analysis"]
        if self._return_steps:
            keys.append("chain_outputs")
        return keys

    def _package_result(self, chain_outputs: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap raw chain outputs into a SlideAnalysis result dict and log it."""
        result = dict(slide_analysis=SlideAnalysis(**chain_outputs))
        self.log_result(result)
        if self._return_steps:
            result["chain_outputs"] = chain_outputs
        return result

    def _call(
        self, inputs: Dict[str, Any], run_manager: Optional[Any] = None
    ) -> Dict[str, Any]:
        """Process a single slide.

        Args:
            inputs: Dictionary containing:
                - pdf_path: Path to PDF file
                - page_num: Page number to process

        Returns:
            Dictionary with SlideAnalysis object and optionally chain outputs.
        """
        return self._package_result(self._chain.invoke(inputs))

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process a single slide asynchronously."""
        return self._package_result(await self._chain.ainvoke(inputs))

    def log_result(self, result: Dict[str, Any]) -> None:
        """Log output size; warn when the response looks like a failure."""
        slide_analysis = result["slide_analysis"]
        page_num = slide_analysis.page_num
        pres_name = slide_analysis.pdf_path.stem
        out_len = len(slide_analysis.llm_output)
        logger.info(
            "Returned %d symbols for Slide %d in Presentation '%s'",
            out_len,
            page_num,
            pres_name,
        )
        # Heuristic threshold: outputs shorter than 30 chars are treated as
        # unprocessed slides.
        if out_len < 30:
            logger.warning(
                "Slide %d in Presentation '%s' was not processed",
                page_num,
                pres_name,
            )
class PresentationPipeline(Chain):
    """Pipeline for processing an entire PDF presentation.

    Runs a ``SingleSlidePipeline`` over every page, with resumable interim
    saves and (in the async path) bounded concurrency.
    """

    navigator: Navigator = Navigator()

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        vision_prompt: str = "Describe this slide in detail",
        dpi: int = 72,
        base_path: Optional[Path] = None,
        fresh_start: bool = True,
        save_steps: bool = True,
        save_final: bool = True,
        max_concurrency: int = 5,
        **kwargs,
    ):
        """Initialize pipeline for full presentation processing.

        Args:
            llm: Language model with vision capabilities.
            vision_prompt: Prompt for slide analysis.
            dpi: Resolution for PDF rendering.
            base_path: Base path for storing analysis results.
            fresh_start: Ignore previously saved analyses when True.
            save_steps: Save interim results after each processed slide.
            save_final: Save final results once processing completes.
            max_concurrency: Maximum number of slides processed concurrently
                in the async path.
        """
        super().__init__(**kwargs)
        self._vision_prompt = str(vision_prompt)
        self._slide_pipeline = SingleSlidePipeline(
            llm=llm, vision_prompt=vision_prompt, dpi=dpi
        )
        self._base_path = base_path
        self._fresh_start = fresh_start
        self._save_steps = save_steps
        self._save_final = save_final
        # Keep the limit so logging never has to read the semaphore's
        # private ``_value`` attribute.
        self._max_concurrency = max_concurrency
        self._semaphore = asyncio.Semaphore(max_concurrency)

    # LangChain's Chain declares input_keys/output_keys as abstract
    # properties; plain methods here would break Chain.invoke's key checks.
    @property
    def input_keys(self) -> List[str]:
        """Required input keys."""
        return ["pdf_path"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain."""
        return ["presentation"]

    def _get_timestamped_filename(self, fname: str) -> str:
        """Generate a timestamped filename for analysis results.

        Args:
            fname: Prefix for the filename (usually the presentation name).

        Returns:
            String with format: fname_YYYYMMDD-HHMMSS.json
        """
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        return f"{fname}_{timestamp}.json"

    def _get_interim_save_path(self, pdf_path: Path) -> Path:
        """Get path for saving interim results (directory is created)."""
        interim_dir = (
            self.navigator.get_interim_path(pdf_path.stem)
            if self._base_path is None
            else self._base_path
        )
        interim_dir.mkdir(parents=True, exist_ok=True)
        filename = self._get_timestamped_filename(pdf_path.stem)
        return interim_dir / filename

    def _find_latest_analysis(self, pdf_path: Path) -> Optional[Path]:
        """Find the most recent analysis file for the presentation.

        Args:
            pdf_path: Path to PDF file.

        Returns:
            Path to the latest analysis file, or None if not found.
        """
        search_dir = (
            self._base_path
            if self._base_path
            else self.navigator.get_interim_path(pdf_path.stem)
        )
        if not search_dir.exists():
            return None
        analyses = list(search_dir.glob(f"{pdf_path.stem}_*.json"))
        # Most recently modified file wins; None when no match exists.
        return max(analyses, default=None, key=lambda p: p.stat().st_mtime)

    def _read_document_info(
        self, presentation: PresentationAnalysis, pdf_path: Path
    ) -> int:
        """Read the page count and fill presentation metadata if empty.

        Opens the PDF in a context manager so the document handle is always
        closed (the original code leaked it).

        Returns:
            Number of pages in the document.
        """
        with fitz.open(pdf_path) as doc:
            num_pages = len(doc)
            # Only populate metadata once; guard against a None doc.metadata
            # (the sync/async paths previously disagreed on this check).
            if not presentation.metadata and doc.metadata is not None:
                presentation.metadata = dict(
                    page_count=num_pages,
                    title=doc.metadata.get("title", ""),
                    author=doc.metadata.get("author", ""),
                    subject=doc.metadata.get("subject", ""),
                    keywords=doc.metadata.get("keywords", ""),
                )
        return num_pages

    def _prepare_presentation(
        self, pdf_path: Path
    ) -> Tuple[PresentationAnalysis, Path, List[int]]:
        """Load or create the presentation record and list unprocessed pages.

        Shared preamble of ``_call`` and ``_acall``.

        Returns:
            Tuple of (presentation, interim save path, remaining page numbers).
        """
        latest_analysis = self._find_latest_analysis(pdf_path)
        save_path = self._get_interim_save_path(pdf_path)
        # Resume from the latest saved analysis unless a fresh start is forced.
        presentation = (
            PresentationAnalysis.load(latest_analysis)
            if latest_analysis and not self._fresh_start
            else PresentationAnalysis(
                name=pdf_path.stem, path=pdf_path, vision_prompt=self._vision_prompt
            )
        )
        processed_pages = {slide.page_num for slide in presentation.slides}
        if processed_pages:
            logger.info(
                "Loaded existing analysis with %d slides", len(processed_pages)
            )
        num_pages = self._read_document_info(presentation, pdf_path)
        remaining_pages = [i for i in range(num_pages) if i not in processed_pages]
        return presentation, save_path, remaining_pages

    def _process_slide(self, pdf_path: Path, page_num: int) -> Optional[SlideAnalysis]:
        """Process a single slide, returning None on failure."""
        try:
            result = self._slide_pipeline.invoke(
                {"pdf_path": pdf_path, "page_num": page_num}
            )
            slide_analysis = result["slide_analysis"]
            slide_analysis.reset_vision_prompt()
            return slide_analysis
        except Exception as e:
            logger.error("Failed to process slide %s: %s", page_num, e)
            return None

    def _call(
        self, inputs: Dict[str, Any], run_manager: Optional[Any] = None
    ) -> Dict[str, Any]:
        """Process an entire presentation.

        Args:
            inputs: Dictionary containing:
                - pdf_path: Path to PDF file

        Returns:
            Dictionary with PresentationAnalysis object.
        """
        pdf_path = Path(inputs["pdf_path"])
        presentation, save_path, remaining_pages = self._prepare_presentation(pdf_path)
        for page_num in tqdm(remaining_pages, desc="Processing slides"):
            slide = self._process_slide(pdf_path, page_num)
            if slide:
                presentation.slides.append(slide)
                # Save progress after each slide so a crash loses little work.
                if self._save_steps:
                    presentation.save(save_path)
        presentation.slides.sort(key=lambda x: x.page_num)
        if self._save_final:
            presentation.save(save_path)
        self.log_result(presentation)
        return dict(presentation=presentation)

    async def _aprocess_slide(
        self, pdf_path: Path, page_num: int
    ) -> Optional[SlideAnalysis]:
        """Process a single slide asynchronously, returning None on failure."""
        try:
            result = await self._slide_pipeline.ainvoke(
                {"pdf_path": pdf_path, "page_num": page_num}
            )
            slide_analysis = result["slide_analysis"]
            slide_analysis.reset_vision_prompt()
            return slide_analysis
        except Exception as e:
            logger.error("Failed to process slide %s: %s", page_num, e)
            return None

    async def _process_slide_with_semaphore(
        self, pdf_path: Path, page_num: int
    ) -> Optional[SlideAnalysis]:
        """Process a single slide under the concurrency semaphore."""
        async with self._semaphore:
            return await self._aprocess_slide(pdf_path, page_num)

    async def _process_slides_in_batches(
        self,
        pdf_path: Path,
        remaining_pages: List[int],
        presentation: PresentationAnalysis,
        save_path: Path,
    ) -> None:
        """Process slides with controlled concurrency and save progress.

        Args:
            pdf_path: Path to PDF file.
            remaining_pages: List of page numbers to process.
            presentation: Current presentation analysis (mutated in place).
            save_path: Path to save interim results.
        """
        tasks = [
            self._process_slide_with_semaphore(pdf_path, page_num)
            for page_num in remaining_pages
        ]
        # Slides are appended in completion order; callers must sort afterwards.
        for task in tqdm(
            asyncio.as_completed(tasks),
            desc=f"Processing slides (max {self._max_concurrency} concurrent)",
            total=len(tasks),
        ):
            slide = await task
            if slide:
                presentation.slides.append(slide)
                if self._save_steps:
                    presentation.save(save_path)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process an entire presentation asynchronously with bounded concurrency."""
        pdf_path = Path(inputs["pdf_path"])
        presentation, save_path, remaining_pages = self._prepare_presentation(pdf_path)
        if remaining_pages:
            await self._process_slides_in_batches(
                pdf_path, remaining_pages, presentation, save_path
            )
        # Completion order is nondeterministic, so restore page order here.
        presentation.slides.sort(key=lambda x: x.page_num)
        if self._save_final:
            presentation.save(save_path)
        self.log_result(presentation)
        return dict(presentation=presentation)

    def log_result(self, presentation: PresentationAnalysis) -> None:
        """Log completion of a presentation run."""
        pres_name = presentation.name
        logger.info("Finished processing %s", pres_name)