"""LangChain chains for locating a PDF, loading a page, rendering it and analysing it with a vision model."""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import fitz
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_core.callbacks import AsyncCallbackManagerForChainRun
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from PIL import Image

from src.chains.chain_funcs import get_param_or_default
from src.chains.prompts import JsonH1AndGDPrompt, SimpleVisionPrompt
from src.config.navigator import Navigator
from src.processing import image2base64, page2image

logger = logging.getLogger(__name__)


class FindPdfChain(Chain):
    """Chain for finding PDF file given substring of a filename"""

    navigator: Navigator = Navigator()

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["pdf_path"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["pdf_path"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Resolve the ``pdf_path`` input to a PDF location

        A ``str`` input is treated as a filename substring and looked up via
        ``navigator.find_file_by_substr``. Any other input is wrapped in
        ``Path``; a relative path is resolved through
        ``navigator.get_absolute_path``.

        Args:
            inputs: Dictionary containing:
                - pdf_path: Substring to search in PDF filenames (str)
                  or an actual path (Path)
            run_manager: Callback manager (unused)

        Returns:
            Dictionary with the resolved path under ``pdf_path``

        Raises:
            ValueError: If no PDF matches the given substring.
                NOTE(review): earlier documentation also listed "multiple
                PDFs match the substring" as a ValueError; presumably that is
                raised inside ``Navigator.find_file_by_substr`` - confirm.
        """
        fpath_or_name: Union[Path, str] = inputs["pdf_path"]

        if isinstance(fpath_or_name, str):
            pdf_path = self.navigator.find_file_by_substr(fpath_or_name)
            if pdf_path is None:
                raise ValueError(f"No PDF found matching '{fpath_or_name}'")
        else:
            pdf_path = Path(fpath_or_name)
            if not pdf_path.is_absolute():
                pdf_path = self.navigator.get_absolute_path(pdf_path)

        return dict(pdf_path=pdf_path)


class LoadPageChain(Chain):
    """Chain for loading PyMuPDF page"""

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["pdf_path", "page_num"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["page"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Load PyMuPDF page

        Args:
            inputs: Dictionary containing:
                - pdf_path: Path to PDF file
                - page_num: Index used to select the page from the document
            run_manager: Callback manager (unused)

        Returns:
            Dictionary with the PyMuPDF page under ``page``
        """
        pdf_path: Path = inputs["pdf_path"]
        page_num: int = inputs["page_num"]

        # The document is deliberately not closed here: the returned page is
        # still used by downstream chains and depends on the open document.
        # NOTE(review): nothing in this chain ever closes the document -
        # confirm whether a caller is expected to manage its lifetime.
        pdf_file = fitz.open(pdf_path)
        page = pdf_file[page_num]
        return dict(page=page)


class Page2ImageChain(Chain):
    """Chain for converting PyMuPDF page to PIL Image"""

    def __init__(self, default_dpi: int = 72, **kwargs):
        """Initialize Page to Image conversion chain

        Args:
            default_dpi: Default resolution for PDF rendering
            **kwargs: Forwarded to the ``Chain`` constructor
        """
        super().__init__(**kwargs)
        self._default_dpi = default_dpi

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["page"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["image"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Convert PyMuPDF page to PIL Image

        Args:
            inputs: Dictionary containing:
                - page: PyMuPDF page object
                - dpi: Optional DPI value for rendering; resolved through
                  ``get_param_or_default`` with the chain's default DPI
                  as the fallback
            run_manager: Callback manager (unused)

        Returns:
            Dictionary with the rendered image under ``image``
        """
        page: fitz.Page = inputs["page"]
        dpi = get_param_or_default(inputs, "dpi", self._default_dpi)
        image = page2image(page, dpi)
        return dict(image=image)


class ImageEncodeChain(Chain):
    """Chain for encoding PIL Images to base64 strings"""

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["image"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["image_encoded"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Encode PIL Image to base64 string

        Args:
            inputs: Dictionary with the PIL Image under ``image``
            run_manager: Callback manager (unused)

        Returns:
            Dictionary with the encoded image under ``image_encoded``
        """
        image: Image.Image = inputs["image"]
        encoded = image2base64(image)
        return dict(image_encoded=encoded)


class VisionAnalysisChain(Chain):
    """Single image analysis chain"""

    @property
    def input_keys(self) -> List[str]:
        """Required input keys for the chain"""
        return ["image_encoded"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["vision_prompt", "llm_output", "parsed_output"]

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        prompt: str = "Describe this slide in detail",
        **kwargs,
    ):
        """Initialize the chain with vision capabilities

        Args:
            llm: Language model with vision capabilities (e.g. GPT-4V)
            prompt: An instruction passed to the vision model
            **kwargs: Forwarded to the ``Chain`` constructor
        """
        super().__init__(**kwargs)

        # Store components as instance variables without class-level declarations
        self._llm = llm
        self._prompt = prompt

    def setup_chain(self, inputs: Dict[str, Any]):
        """Build the runnable pipeline and pick the prompt for this call

        The prompt comes from ``inputs["vision_prompt"]`` when supplied
        (resolved through ``get_param_or_default``), otherwise from the
        prompt given at construction. A plain string is wrapped in
        ``SimpleVisionPrompt``.

        Args:
            inputs: Chain inputs; may carry an optional ``vision_prompt``

        Returns:
            Tuple ``(chain, current_prompt)``: the composed runnable and the
            prompt object it was built from
        """
        current_prompt = get_param_or_default(inputs, "vision_prompt", self._prompt)
        if isinstance(current_prompt, str):
            current_prompt = SimpleVisionPrompt(current_prompt)

        # NOTE(review): `llm` defaults to None in __init__ and is piped here
        # unchecked - confirm that callers always provide a model.
        chain = (
            current_prompt.template  # type: ignore
            | self._llm
            | dict(
                # Fan the model reply out: plain text for parsing, plus the
                # untouched message so its metadata stays available.
                llm_output=StrOutputParser(),
                message=RunnablePassthrough(),  # AIMessage(content)
            )
        )
        return chain, current_prompt

    def _chain_inputs(self, current_prompt: Any, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the payload handed to the runnable pipeline"""
        return {"prompt": current_prompt, "image_base64": inputs["image_encoded"]}

    def _build_result(self, current_prompt: Any, out: Any) -> Dict[str, Any]:
        """Turn the pipeline output into the chain's result dictionary"""
        return dict(
            llm_output=out["llm_output"],  # type: ignore
            parsed_output=current_prompt.parse(out["llm_output"]),  # type: ignore
            response_metadata=out["message"].response_metadata,  # type: ignore
            vision_prompt=current_prompt.prompt_text,  # type: ignore
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process single image with the vision model

        Args:
            inputs: Dictionary containing:
                - image_encoded: base64 encoded image string
                - vision_prompt: Optional custom prompt used instead of
                  the one defined in __init__
            run_manager: Callback manager (unused)

        Returns:
            Dictionary containing:
                - llm_output: Text of the model's reply
                - parsed_output: ``llm_output`` passed through the
                  prompt's ``parse`` method
                - response_metadata: ``response_metadata`` of the model
                  message (returned in addition to ``output_keys``)
                - vision_prompt: ``prompt_text`` of the prompt that was used
        """
        chain, current_prompt = self.setup_chain(inputs)
        out = chain.invoke(self._chain_inputs(current_prompt, inputs))
        return self._build_result(current_prompt, out)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronous counterpart of ``_call``

        Takes the same inputs and returns the same dictionary, but awaits
        the pipeline's ``ainvoke`` instead of calling ``invoke``.
        """
        chain, current_prompt = self.setup_chain(inputs)
        out = await chain.ainvoke(self._chain_inputs(current_prompt, inputs))
        return self._build_result(current_prompt, out)