Ilia Tambovtsev committed on
Commit
1ba19b2
·
1 Parent(s): eeea597

style: reorder chains

Browse files

Chains that operate on a single page go first; the batched ones follow.

Files changed (1) hide show
  1. src/pdf_utils/chains.py +107 -100
src/pdf_utils/chains.py CHANGED
@@ -144,6 +144,113 @@ class Page2ImageChain(Chain):
144
  return dict(image=image)
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  class Pdf2ImageChain(Chain):
148
  """Chain for converting PDF pages to PIL Images using PyMuPDF"""
149
 
@@ -249,42 +356,6 @@ class Pdf2ImageChain(Chain):
249
  return result
250
 
251
 
252
- class ImageEncodeChain(Chain):
253
- """Chain for encoding PIL Images to base64 strings"""
254
-
255
- @property
256
- def input_keys(self) -> List[str]:
257
- return ["image"]
258
-
259
- @property
260
- def output_keys(self) -> List[str]:
261
- return ["image_encoded"]
262
-
263
- def _call(
264
- self,
265
- inputs: Dict[str, Any],
266
- run_manager: Optional[CallbackManagerForChainRun] = None
267
- ) -> Dict[str, Any]:
268
- """Encode PIL Image to base64 string
269
-
270
- Args:
271
- inputs: Dictionary with PIL Image
272
- run_manager: Callback manager
273
-
274
- Returns:
275
- Dictionary with base64 encoded image string
276
- """
277
- image: Image.Image = inputs["image"]
278
-
279
- # Save image to bytes buffer
280
- buffer = BytesIO()
281
- image.save(buffer, format="PNG")
282
-
283
- # Encode to base64
284
- encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
285
-
286
- return dict(image_encoded=encoded)
287
-
288
  class PDFLoaderChain(Chain):
289
  """Chain for loading PDF paths from weird-slides directory"""
290
 
@@ -369,67 +440,3 @@ class ImageLoaderChain(Chain):
369
  return {"image": image_base64}
370
 
371
 
372
- class VisionAnalysisChain(Chain):
373
- """Single image analysis chain"""
374
-
375
- @property
376
- def input_keys(self) -> List[str]:
377
- """Required input keys for the chain"""
378
- return ["image_encoded"]
379
-
380
- @property
381
- def output_keys(self) -> List[str]:
382
- """Output keys provided by the chain"""
383
- return ["llm_output"]
384
-
385
- def __init__(
386
- self,
387
- llm: Optional[ChatOpenAI] = None,
388
- prompt: str = "Describe this slide in detail",
389
- **kwargs
390
- ):
391
- """Initialize the chain with vision capabilities
392
-
393
- Args:
394
- llm: Language model with vision capabilities (e.g. GPT-4V)
395
- prompt: Custom prompt for slide analysis
396
- """
397
- super().__init__(**kwargs)
398
-
399
- # Store components as instance variables without class-level declarations
400
- self._llm = llm
401
- self._prompt = prompt
402
-
403
- self._vision_prompt_template = ChatPromptTemplate.from_messages([
404
- ("human", [
405
- {"type": "text", "text": "{prompt}"},
406
- {
407
- "type": "image",
408
- "image_url": "data:image/png;base64,{image_base64}"
409
- }
410
- ])
411
- ])
412
-
413
- self._chain = (
414
- self._vision_prompt_template
415
- | self._llm
416
- | dict(llm_output=StrOutputParser())
417
- )
418
-
419
- def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
420
- """Process single image with the vision model
421
-
422
- Args:
423
- inputs: Dictionary containing:
424
- - image: base64 encoded image string
425
- - vision_prompt: Optional custom prompt used instead of defined in __init__
426
-
427
- Returns:
428
- Dictionary with `analysis` - model's output
429
- """
430
- current_prompt = get_param_or_default(inputs, "vision_prompt", self._prompt)
431
-
432
- return self._chain.invoke({
433
- "prompt": current_prompt,
434
- "image_base64": inputs["image_encoded"]
435
- })
 
144
  return dict(image=image)
145
 
146
 
147
class ImageEncodeChain(Chain):
    """Chain that serializes a PIL Image into a base64-encoded PNG string."""

    @property
    def input_keys(self) -> List[str]:
        """Single required input: the PIL Image under ``image``."""
        return ["image"]

    @property
    def output_keys(self) -> List[str]:
        """Single output: the base64 string under ``image_encoded``."""
        return ["image_encoded"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, Any]:
        """Encode a PIL Image to a base64 string.

        Args:
            inputs: Dictionary holding the PIL Image under ``image``
            run_manager: Callback manager

        Returns:
            Dictionary with the base64-encoded image under ``image_encoded``
        """
        pil_image: Image.Image = inputs["image"]

        # Serialize the image as PNG into an in-memory buffer,
        # then base64-encode the raw bytes.
        png_buffer = BytesIO()
        pil_image.save(png_buffer, format="PNG")
        b64_string = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

        return {"image_encoded": b64_string}
182
+
183
+
184
class VisionAnalysisChain(Chain):
    """Single image analysis chain.

    Sends one base64-encoded image plus a text prompt to a vision-capable
    chat model and returns the model's textual description.
    """

    @property
    def input_keys(self) -> List[str]:
        """Required input keys for the chain"""
        return ["image_encoded"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["llm_output"]

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        prompt: str = "Describe this slide in detail",
        **kwargs
    ):
        """Initialize the chain with vision capabilities

        Args:
            llm: Language model with vision capabilities (e.g. GPT-4V)
            prompt: Custom prompt for slide analysis
        """
        super().__init__(**kwargs)

        # Store components as instance variables without class-level declarations
        self._llm = llm
        self._prompt = prompt

        # Single human message carrying both the text prompt and the
        # inline base64 image.
        self._vision_prompt_template = ChatPromptTemplate.from_messages([
            ("human", [
                {"type": "text", "text": "{prompt}"},
                {
                    "type": "image",
                    "image_url": "data:image/png;base64,{image_base64}"
                }
            ])
        ])

        # prompt -> llm -> plain text wrapped under the `llm_output` key.
        self._chain = (
            self._vision_prompt_template
            | self._llm
            | dict(llm_output=StrOutputParser())
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None
    ) -> Dict[str, Any]:
        """Process single image with the vision model

        Args:
            inputs: Dictionary containing:
                - image_encoded: base64 encoded image string
                - vision_prompt: Optional custom prompt used instead of the
                  one defined in __init__
            run_manager: Callback manager (accepted for signature consistency
                with the other chains; not used directly)

        Returns:
            Dictionary with `llm_output` - model's output
        """
        current_prompt = get_param_or_default(inputs, "vision_prompt", self._prompt)

        return self._chain.invoke({
            "prompt": current_prompt,
            "image_base64": inputs["image_encoded"]
        })
248
+
249
+
250
+ # The chains below operate on batches of pages rather than single pages.
251
+ # They were written during the first experimental runs.
252
+ # TODO: consider removing them; kept for now in case they are useful later.
253
+
254
  class Pdf2ImageChain(Chain):
255
  """Chain for converting PDF pages to PIL Images using PyMuPDF"""
256
 
 
356
  return result
357
 
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  class PDFLoaderChain(Chain):
360
  """Chain for loading PDF paths from weird-slides directory"""
361
 
 
440
  return {"image": image_base64}
441
 
442