File size: 8,302 Bytes
eecfb09
27810b8
 
eecfb09
27810b8
 
eecfb09
c413127
5627b7a
27810b8
c413127
61053e3
eecfb09
27810b8
c413127
e03b33d
27810b8
552cafa
eecfb09
 
5d220a9
 
 
 
 
 
 
 
 
4c785d2
5d220a9
 
 
 
 
 
 
 
 
27810b8
5d220a9
 
 
 
 
4c785d2
5d220a9
 
 
 
 
 
 
 
4c785d2
 
 
 
 
7c57889
4c785d2
 
 
e4be8f0
 
5d220a9
 
 
e3d7dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27810b8
e3d7dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42b8733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27810b8
42b8733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61053e3
1ba19b2
 
 
 
 
 
 
 
 
 
 
 
 
 
27810b8
1ba19b2
 
 
 
 
 
 
 
 
 
a43286e
764a794
1ba19b2
 
 
 
 
 
 
 
 
 
 
 
 
 
552cafa
1ba19b2
 
 
 
 
27810b8
1ba19b2
 
 
 
 
552cafa
1ba19b2
 
 
 
 
 
 
c413127
1ba19b2
552cafa
 
 
 
 
c413127
552cafa
 
 
27810b8
552cafa
 
c413127
552cafa
c413127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552cafa
2e1d6d3
c413127
 
 
27810b8
2e1d6d3
 
c413127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import fitz
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_core.callbacks import AsyncCallbackManagerForChainRun
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from PIL import Image

from src.chains.chain_funcs import get_param_or_default
from src.chains.prompts import JsonH1AndGDPrompt, SimpleVisionPrompt
from src.config.navigator import Navigator
from src.processing import image2base64, page2image

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class FindPdfChain(Chain):
    """Chain for locating a PDF from a filename substring or a path object"""

    # Helper used for the substring lookup and for anchoring relative paths
    navigator: Navigator = Navigator()

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain consumes"""
        return ["pdf_path"]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces"""
        return ["pdf_path"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Locate the requested PDF

        Args:
            inputs: Dictionary containing:
                - pdf_path: Either a string, which is handed to the
                  navigator's substring lookup, or a path-like value,
                  which is wrapped in `Path` and used as is
            run_manager: Callback manager (not used here)

        Returns:
            Dictionary whose `pdf_path` entry is the located path, passed
            through `navigator.get_absolute_path` when it is not already
            absolute

        Raises:
            ValueError: If the navigator lookup returns None for a string
                input. The lookup itself may raise as well (presumably on
                ambiguous matches - confirm against Navigator).
        """
        requested: Union[Path, str] = inputs["pdf_path"]

        if not isinstance(requested, str):
            located = Path(requested)
        else:
            located = self.navigator.find_file_by_substr(requested)
            if located is None:
                raise ValueError(f"No PDF found matching '{requested}'")

        if located.is_absolute():
            return {"pdf_path": located}
        # Relative results are handed to the navigator for anchoring
        return {"pdf_path": self.navigator.get_absolute_path(located)}


class LoadPageChain(Chain):
    """Chain that opens a PDF with PyMuPDF and picks one page from it"""

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain consumes"""
        return ["pdf_path", "page_num"]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces"""
        return ["page"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Open the document and index into it

        Args:
            inputs: Dictionary containing:
                - pdf_path: Location of the PDF passed to `fitz.open`
                - page_num: Index used to subscript the opened document
            run_manager: Callback manager (not used here)

        Returns:
            Dictionary with the selected page under `page`
        """
        pdf_path: Path
        page_num: int
        pdf_path, page_num = inputs["pdf_path"], inputs["page_num"]

        # NOTE(review): the document is never closed and is only referenced
        # by this local; presumably the returned page keeps it usable -
        # confirm against the PyMuPDF version in use.
        document = fitz.open(pdf_path)
        selected = document[page_num]

        return {"page": selected}


class Page2ImageChain(Chain):
    """Chain that renders a PyMuPDF page through `page2image`"""

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain consumes"""
        return ["page"]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces"""
        return ["image"]

    def __init__(self, default_dpi: int = 72, **kwargs):
        """Set up the page rendering chain

        Args:
            default_dpi: Fallback resolution handed to the renderer when
                the inputs do not supply their own `dpi`
            **kwargs: Forwarded unchanged to the base `Chain` constructor
        """
        super().__init__(**kwargs)
        self._default_dpi = default_dpi

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Render the given page

        Args:
            inputs: Dictionary containing:
                - page: PyMuPDF page object
                - dpi: Optional resolution, looked up through
                  `get_param_or_default` with the chain default as fallback
            run_manager: Callback manager (not used here)

        Returns:
            Dictionary with the `page2image` result under `image`
        """
        source_page: fitz.Page = inputs["page"]
        resolution = get_param_or_default(inputs, "dpi", self._default_dpi)

        return {"image": page2image(source_page, resolution)}


class ImageEncodeChain(Chain):
    """Chain that turns a PIL Image into a base64 string via `image2base64`"""

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain consumes"""
        return ["image"]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces"""
        return ["image_encoded"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Encode the supplied image

        Args:
            inputs: Dictionary holding the PIL Image under `image`
            run_manager: Callback manager (not used here)

        Returns:
            Dictionary with the encoded string under `image_encoded`
        """
        picture: Image.Image = inputs["image"]
        return {"image_encoded": image2base64(picture)}


class VisionAnalysisChain(Chain):
    """Single image analysis chain

    Sends one base64 encoded image together with a prompt through
    `prompt template | llm` and returns the model text, the prompt's parsed
    view of that text, the response metadata and the prompt text used.
    """

    @property
    def input_keys(self) -> List[str]:
        """Required input keys for the chain"""
        return ["image_encoded"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain

        The result dictionary additionally carries `response_metadata`,
        which is not declared here.
        """
        return ["vision_prompt", "llm_output", "parsed_output"]

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        prompt: str = "Describe this slide in detail",
        **kwargs,
    ):
        """Initialize the chain with vision capabilities

        Args:
            llm: Language model with vision capabilities (e.g. GPT-4V)
            prompt: An instruction passed to the vision model
            **kwargs: Forwarded unchanged to the base `Chain` constructor
        """
        super().__init__(**kwargs)

        # Store components as instance variables without class-level declarations
        # NOTE(review): `llm` defaults to None but is piped into the runnable
        # sequence in `setup_chain`; presumably a model must always be
        # supplied - confirm with callers.
        self._llm = llm
        self._prompt = prompt

    def setup_chain(self, inputs: Dict[str, Any]):
        """Build the runnable pipeline for one invocation

        Args:
            inputs: Chain inputs; a `vision_prompt` entry is looked up through
                `get_param_or_default` with the prompt from `__init__` as
                fallback

        Returns:
            Tuple of (runnable pipeline, prompt object). A plain string prompt
            is wrapped in `SimpleVisionPrompt` first.
        """
        current_prompt = get_param_or_default(inputs, "vision_prompt", self._prompt)

        if isinstance(current_prompt, str):
            current_prompt = SimpleVisionPrompt(current_prompt)

        # The trailing dict fans the model message out into two views:
        # its text content and the untouched message object.
        chain = (
            current_prompt.template  # type: ignore
            | self._llm
            | dict(
                llm_output=StrOutputParser(),
                message=RunnablePassthrough(),  # AIMessage(content)
            )
        )
        return chain, current_prompt

    def _model_inputs(self, current_prompt: Any, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the payload handed to the pipeline (shared by sync and async paths)"""
        return {"prompt": current_prompt, "image_base64": inputs["image_encoded"]}

    def _build_result(self, current_prompt: Any, out: Any) -> Dict[str, Any]:
        """Turn the pipeline output into the chain result (shared by sync and async paths)"""
        return dict(
            llm_output=out["llm_output"],  # type: ignore
            parsed_output=current_prompt.parse(out["llm_output"]),  # type: ignore
            response_metadata=out["message"].response_metadata,  # type: ignore
            vision_prompt=current_prompt.prompt_text,  # type: ignore
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process single image with the vision model

        Args:
            inputs: Dictionary containing:
                - image_encoded: base64 encoded image string
                - vision_prompt: Optional custom prompt used instead of defined in __init__
            run_manager: Callback manager (not used here)

        Returns:
            Dictionary with:
                - llm_output: text produced by the model
                - parsed_output: result of the prompt's `parse` on that text
                - response_metadata: metadata taken from the model message
                - vision_prompt: `prompt_text` of the prompt that was used
        """
        chain, current_prompt = self.setup_chain(inputs)

        out = chain.invoke(self._model_inputs(current_prompt, inputs))

        return self._build_result(current_prompt, out)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of `_call`

        Takes the same inputs and returns the same dictionary; the pipeline
        is awaited through `ainvoke` instead of `invoke`.
        """
        chain, current_prompt = self.setup_chain(inputs)

        out = await chain.ainvoke(self._model_inputs(current_prompt, inputs))

        return self._build_result(current_prompt, out)