Ilia Tambovtsev committed on
Commit
0eae301
·
1 Parent(s): 742cb9c

feat: add pipelines

Browse files
Files changed (1) hide show
  1. src/pdf_utils/pipelines.py +313 -0
src/pdf_utils/pipelines.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ from pydantic import BaseModel, Field
3
+ from pathlib import Path
4
+ import json
5
+ import logging
6
+ from tqdm import tqdm
7
+ from datetime import datetime
8
+ import fitz
9
+
10
+ from langchain_openai.chat_models import ChatOpenAI
11
+ from langchain.chains.base import Chain
12
+
13
+ from src.pdf_utils.chains import (
14
+ LoadPageChain,
15
+ Page2ImageChain,
16
+ ImageEncodeChain,
17
+ VisionAnalysisChain
18
+ )
19
+
20
+ from src.config import Navigator
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SlideAnalysis(BaseModel):
27
+ """Container for slide analysis results"""
28
+ page_num: int
29
+ vision_prompt: str
30
+ content: str
31
+
32
+
33
class PresentationAnalysis(BaseModel):
    """Container for presentation analysis results.

    Aggregates per-slide analyses plus document metadata, and supports
    round-tripping to/from a UTF-8 JSON file.
    """
    name: str
    path: Path
    metadata: Dict = Field(default_factory=dict)
    slides: List[SlideAnalysis] = Field(default_factory=list)
    timestamp: str = Field(
        default_factory=lambda: datetime.now().isoformat()
    )

    def save(self, save_path: Path):
        """Serialize this analysis to *save_path* as UTF-8 JSON."""
        payload = self.model_dump()
        # pathlib.Path is not JSON-serializable; store it as a string
        payload["path"] = str(payload["path"])
        text = json.dumps(payload, indent=2, ensure_ascii=False)
        save_path.write_text(text, encoding="utf-8")

    @classmethod
    def load(cls, load_path: Path) -> "PresentationAnalysis":
        """Deserialize a previously saved analysis from *load_path*."""
        payload = json.loads(load_path.read_text(encoding="utf-8"))
        # Restore the Path object that save() stringified
        payload["path"] = Path(payload["path"])
        return cls(**payload)
60
+
61
+
62
class SingleSlidePipeline(Chain):
    """Pipeline that analyzes one slide of a PDF with a vision LLM."""

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        vision_prompt: str = "Describe this slide in detail",
        dpi: int = 75,
        **kwargs
    ):
        """Initialize pipeline for single slide processing

        Args:
            llm: Language model with vision capabilities
            vision_prompt: Prompt for slide analysis
            dpi: Resolution for PDF rendering
        """
        super().__init__(**kwargs)

        # Compose the per-slide steps with the pipe operator:
        # load page -> render to image -> encode image -> query vision model
        self._chain = (
            LoadPageChain()
            | Page2ImageChain(default_dpi=dpi)
            | ImageEncodeChain()
            | VisionAnalysisChain(llm=llm, prompt=vision_prompt)
        )

    @property
    def input_keys(self) -> List[str]:
        """Keys that must be present in the input dictionary."""
        return ["pdf_path", "page_num"]

    @property
    def output_keys(self) -> List[str]:
        """Keys present in the output dictionary."""
        return ["slide_analysis"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[Any] = None
    ) -> Dict[str, Any]:
        """Run the inner chain on one slide and wrap the result.

        Args:
            inputs: Dictionary containing:
                - pdf_path: Path to PDF file
                - page_num: Page number to process

        Returns:
            Dictionary with a SlideAnalysis object under "slide_analysis"
        """
        chain_output = self._chain.invoke(inputs)
        analysis = SlideAnalysis(
            page_num=inputs["page_num"],
            vision_prompt=chain_output["vision_prompt"],
            content=chain_output["llm_output"],
        )
        return {"slide_analysis": analysis}
122
+
123
+
124
class PresentationPipeline(Chain):
    """Pipeline for processing an entire PDF presentation slide by slide.

    Results accumulate in a ``PresentationAnalysis``. Progress can be saved
    after every slide (``save_steps``) so an interrupted run resumes from
    the most recent analysis file found on disk.
    """

    # Project path helper used to locate the interim results directory
    navigator: Navigator = Navigator()

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        vision_prompt: str = "Describe this slide in detail",
        dpi: int = 75,
        base_path: Optional[Path] = None,
        save_steps: bool = True,
        save_final: bool = True,
        **kwargs
    ):
        """Initialize pipeline for full presentation processing

        Args:
            llm: Language model with vision capabilities
            vision_prompt: Prompt for slide analysis
            dpi: Resolution for PDF rendering
            base_path: Base path for storing analysis results
            save_steps: Save progress after each processed slide
            save_final: Save the complete result at the end of the run
        """
        super().__init__(**kwargs)
        self._slide_pipeline = SingleSlidePipeline(
            llm=llm,
            vision_prompt=vision_prompt,
            dpi=dpi
        )
        self._base_path = base_path
        self._save_steps = save_steps
        self._save_final = save_final

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["pdf_path"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["presentation"]

    def _get_timestamped_filename(self, prefix: str) -> str:
        """Generate timestamped filename for analysis results

        Args:
            prefix: Prefix for the filename (usually presentation name)

        Returns:
            String with format: prefix_YYYYMMDD_HHMMSS.json
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"{prefix}_{timestamp}.json"

    def _get_interim_save_path(self, pdf_path: Path) -> Path:
        """Build (and create the directory for) the interim results path."""
        interim_dir = (
            self.navigator.get_interim_path(pdf_path.stem)
            if self._base_path is None
            else self._base_path
        )

        interim_dir.mkdir(parents=True, exist_ok=True)
        filename = self._get_timestamped_filename(pdf_path.stem)
        return interim_dir / filename

    def _find_latest_analysis(self, pdf_path: Path) -> Optional[Path]:
        """Find most recent analysis file for the presentation

        Args:
            pdf_path: Path to PDF file

        Returns:
            Path to latest analysis file or None if not found
        """
        search_dir = (
            self._base_path if self._base_path
            else self.navigator.get_interim_path(pdf_path.stem)
        )

        if not search_dir.exists():
            return None

        # Most recently modified file wins; None when no match exists
        analyses = list(search_dir.glob(f"{pdf_path.stem}_*.json"))
        return max(analyses, default=None, key=lambda p: p.stat().st_mtime)

    def _process_slide(self, pdf_path: Path, page_num: int) -> Optional[SlideAnalysis]:
        """Process a single slide, returning None (and logging) on failure.

        Failures are swallowed deliberately so one bad page does not abort
        the whole presentation run.
        """
        try:
            result = self._slide_pipeline.invoke({
                "pdf_path": pdf_path,
                "page_num": page_num
            })
            return result["slide_analysis"]
        except Exception as e:
            # Lazy %-args: the message is only built if the record is emitted
            logger.error("Failed to process slide %s: %s", page_num, e)
            return None

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[Any] = None
    ) -> Dict[str, Any]:
        """Process entire presentation

        Args:
            inputs: Dictionary containing:
                - pdf_path: Path to PDF file

        Returns:
            Dictionary with PresentationAnalysis object
        """
        pdf_path = Path(inputs["pdf_path"])
        latest_analysis = self._find_latest_analysis(pdf_path)
        save_path = self._get_interim_save_path(pdf_path)

        # Resume from the most recent saved analysis when one exists
        presentation = (
            PresentationAnalysis.load(latest_analysis)
            if latest_analysis
            else PresentationAnalysis(name=pdf_path.stem, path=pdf_path)
        )

        # Pages already analyzed in a previous (possibly interrupted) run
        processed_pages = {slide.page_num for slide in presentation.slides}

        if processed_pages:
            logger.info(
                "Loaded existing analysis with %d slides", len(processed_pages)
            )

        # Read page count and document metadata; close the document in a
        # finally block so the file handle is released even on error
        # (the original leaked the open fitz document).
        doc = fitz.open(pdf_path)
        try:
            num_pages = len(doc)

            # Update metadata if not present
            if not presentation.metadata:
                presentation.metadata = dict(
                    page_count=num_pages,
                    title=doc.metadata.get("title", ""),
                    author=doc.metadata.get("author", ""),
                    subject=doc.metadata.get("subject", ""),
                    keywords=doc.metadata.get("keywords", "")
                )
        finally:
            doc.close()

        # Process only the pages missing from the loaded analysis
        remaining_pages = [i for i in range(num_pages) if i not in processed_pages]

        for page_num in tqdm(remaining_pages, desc="Processing slides"):
            slide = self._process_slide(pdf_path, page_num)
            if slide:
                presentation.slides.append(slide)
                # Save progress after each slide so interruptions lose at
                # most one slide of work
                if self._save_steps:
                    presentation.save(save_path)

        # Keep slides ordered by page regardless of processing order
        presentation.slides.sort(key=lambda x: x.page_num)

        if self._save_final:
            presentation.save(save_path)
        return dict(presentation=presentation)
286
+
287
+
288
def process_presentation(
    pdf_path: Path,
    llm: Optional[ChatOpenAI] = None,
    vision_prompt: str = "Describe this slide in detail",
    dpi: int = 300,
    base_path: Optional[Path] = None,
    save_steps: bool = True,
    save_final: bool = True
) -> PresentationAnalysis:
    """Convenience function for presentation processing

    Thin wrapper around ``PresentationPipeline``. The new keyword-only-by-
    position parameters ``save_steps``/``save_final`` default to the
    pipeline's own defaults, so existing callers are unaffected.

    Args:
        pdf_path: Path to PDF file
        llm: Language model with vision capabilities
        vision_prompt: Prompt for slide analysis
        dpi: Resolution for PDF rendering (NOTE: defaults to 300 here,
            while the pipeline classes default to 75)
        base_path: Optional custom path for storing results
        save_steps: Save progress after each processed slide
        save_final: Save the complete result at the end of the run

    Returns:
        PresentationAnalysis object
    """
    pipeline = PresentationPipeline(
        llm=llm,
        vision_prompt=vision_prompt,
        dpi=dpi,
        base_path=base_path,
        save_steps=save_steps,
        save_final=save_final
    )
    return pipeline.invoke({"pdf_path": pdf_path})["presentation"]