Spaces:
Build error
Build error
File size: 8,302 Bytes
eecfb09 27810b8 eecfb09 27810b8 eecfb09 c413127 5627b7a 27810b8 c413127 61053e3 eecfb09 27810b8 c413127 e03b33d 27810b8 552cafa eecfb09 5d220a9 4c785d2 5d220a9 27810b8 5d220a9 4c785d2 5d220a9 4c785d2 7c57889 4c785d2 e4be8f0 5d220a9 e3d7dbf 27810b8 e3d7dbf 42b8733 27810b8 42b8733 61053e3 1ba19b2 27810b8 1ba19b2 a43286e 764a794 1ba19b2 552cafa 1ba19b2 27810b8 1ba19b2 552cafa 1ba19b2 c413127 1ba19b2 552cafa c413127 552cafa 27810b8 552cafa c413127 552cafa c413127 552cafa 2e1d6d3 c413127 27810b8 2e1d6d3 c413127 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 | import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fitz
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_core.callbacks import AsyncCallbackManagerForChainRun
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from PIL import Image
from src.chains.chain_funcs import get_param_or_default
from src.chains.prompts import JsonH1AndGDPrompt, SimpleVisionPrompt
from src.config.navigator import Navigator
from src.processing import image2base64, page2image
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class FindPdfChain(Chain):
    """Chain for resolving a PDF file from a filename substring or a path."""

    # Used to look up PDFs by filename substring and to absolutize
    # relative paths.
    navigator: Navigator = Navigator()

    @property
    def input_keys(self) -> List[str]:
        """Required input keys"""
        return ["pdf_path"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys provided by the chain"""
        return ["pdf_path"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Resolve ``inputs["pdf_path"]`` to a PDF path.

        A ``str`` value is treated as a substring of a PDF filename and is
        looked up with ``navigator.find_file_by_substr``. Any other value is
        converted to ``Path`` and, if relative, made absolute with
        ``navigator.get_absolute_path``.

        Args:
            inputs: Dictionary containing:
                - pdf_path: Filename substring (``str``) or a path (``Path``).
            run_manager: Callback manager (unused).

        Returns:
            Dictionary with the resolved ``pdf_path``. It is never ``None``:
            a failed substring lookup raises instead.

        Raises:
            ValueError: If no PDF matches the given substring. Other errors
                (e.g. for ambiguous matches) may propagate from
                ``Navigator.find_file_by_substr`` -- TODO confirm there.
        """
        fpath_or_name: Union[Path, str] = inputs["pdf_path"]
        if isinstance(fpath_or_name, str):
            # Strings are always interpreted as a search substring, not a path.
            pdf_path = self.navigator.find_file_by_substr(fpath_or_name)
            if pdf_path is None:
                raise ValueError(f"No PDF found matching '{fpath_or_name}'")
        else:
            pdf_path = Path(fpath_or_name)
            if not pdf_path.is_absolute():
                pdf_path = self.navigator.get_absolute_path(pdf_path)
        return dict(pdf_path=pdf_path)
class LoadPageChain(Chain):
    """Open a PDF with PyMuPDF and hand back one of its pages."""

    @property
    def input_keys(self) -> List[str]:
        """Keys that must be present in the chain input."""
        required = ["pdf_path", "page_num"]
        return required

    @property
    def output_keys(self) -> List[str]:
        """Keys this chain produces."""
        produced = ["page"]
        return produced

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Return the requested page of the PDF document.

        Args:
            inputs: Mapping with
                - ``pdf_path``: location of the PDF document
                - ``page_num``: page index, passed straight to the
                  document's ``[]`` indexing
            run_manager: Callback manager (unused).

        Returns:
            ``{"page": <PyMuPDF page>}``
        """
        source: Path = inputs["pdf_path"]
        index: int = inputs["page_num"]
        # The document is not closed here: the returned page presumably
        # depends on its parent document staying open.
        document = fitz.open(source)
        return {"page": document[index]}
class Page2ImageChain(Chain):
    """Render a PyMuPDF page into a PIL image."""

    def __init__(self, default_dpi: int = 72, **kwargs):
        """Create the renderer.

        Args:
            default_dpi: Fallback resolution handed to
                ``get_param_or_default`` for the ``dpi`` input.
            **kwargs: Forwarded unchanged to ``Chain.__init__``.
        """
        super().__init__(**kwargs)
        self._default_dpi = default_dpi

    @property
    def input_keys(self) -> List[str]:
        """Keys that must be present in the chain input."""
        required = ["page"]
        return required

    @property
    def output_keys(self) -> List[str]:
        """Keys this chain produces."""
        produced = ["image"]
        return produced

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Rasterise the input page.

        Args:
            inputs: Mapping with
                - ``page``: PyMuPDF page object
                - ``dpi`` (optional): resolution override, resolved via
                  ``get_param_or_default`` against ``self._default_dpi``
            run_manager: Callback manager (unused).

        Returns:
            ``{"image": <result of page2image>}``
        """
        source_page: fitz.Page = inputs["page"]
        resolution = get_param_or_default(inputs, "dpi", self._default_dpi)
        return {"image": page2image(source_page, resolution)}
class ImageEncodeChain(Chain):
    """Turn a PIL image into a base64 string via ``image2base64``."""

    @property
    def input_keys(self) -> List[str]:
        """Keys that must be present in the chain input."""
        required = ["image"]
        return required

    @property
    def output_keys(self) -> List[str]:
        """Keys this chain produces."""
        produced = ["image_encoded"]
        return produced

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Encode ``inputs["image"]``.

        Args:
            inputs: Mapping with ``image`` holding a PIL image.
            run_manager: Callback manager (unused).

        Returns:
            ``{"image_encoded": <base64 string>}``
        """
        picture: Image.Image = inputs["image"]
        return {"image_encoded": image2base64(picture)}
class VisionAnalysisChain(Chain):
    """Analyse a single base64-encoded image with a vision-capable chat model."""

    @property
    def input_keys(self) -> List[str]:
        """Required input keys for the chain"""
        return ["image_encoded"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys guaranteed by the chain.

        The result dict additionally contains ``response_metadata`` (taken
        from the model's message), which is not listed here.
        """
        return ["vision_prompt", "llm_output", "parsed_output"]

    def __init__(
        self,
        llm: Optional[ChatOpenAI] = None,
        prompt: str = "Describe this slide in detail",
        **kwargs,
    ):
        """Initialize the chain with vision capabilities.

        Args:
            llm: Chat model with vision capabilities (e.g. GPT-4V). Required
                to run the chain; ``setup_chain`` raises if it is ``None``.
            prompt: Default instruction passed to the vision model. It is
                wrapped in ``SimpleVisionPrompt`` when used.
            **kwargs: Forwarded to ``Chain.__init__``.
        """
        super().__init__(**kwargs)
        # Store components as instance variables without class-level declarations
        self._llm = llm
        self._prompt = prompt

    def setup_chain(self, inputs: Dict[str, Any]):
        """Build the runnable pipeline for one invocation.

        The prompt comes from ``inputs["vision_prompt"]`` when supplied
        (resolved by ``get_param_or_default``), otherwise from the
        constructor. A plain ``str`` is wrapped in ``SimpleVisionPrompt``.

        Returns:
            Tuple ``(chain, prompt)``. ``chain`` yields a dict with
            ``llm_output`` (parsed string) and ``message`` (raw model message).

        Raises:
            ValueError: If the chain was constructed without an ``llm``.
        """
        if self._llm is None:
            # Fail with a clear message instead of an opaque error from
            # piping ``None`` into the runnable sequence.
            raise ValueError("VisionAnalysisChain requires an `llm`, got None")
        current_prompt = get_param_or_default(inputs, "vision_prompt", self._prompt)
        if isinstance(current_prompt, str):
            current_prompt = SimpleVisionPrompt(current_prompt)
        chain = (
            current_prompt.template  # type: ignore
            | self._llm
            | dict(
                llm_output=StrOutputParser(),
                message=RunnablePassthrough(),  # AIMessage(content)
            )
        )
        return chain, current_prompt

    @staticmethod
    def _build_result(out: Dict[str, Any], current_prompt: Any) -> Dict[str, Any]:
        """Assemble the chain's output dict from the raw pipeline result."""
        llm_output = out["llm_output"]
        return dict(
            llm_output=llm_output,
            parsed_output=current_prompt.parse(llm_output),
            response_metadata=out["message"].response_metadata,
            vision_prompt=current_prompt.prompt_text,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Process a single image with the vision model.

        Args:
            inputs: Dictionary containing:
                - image_encoded: base64-encoded image string
                - vision_prompt: Optional prompt (str or prompt object) used
                  instead of the one given to ``__init__``
            run_manager: Callback manager; its child is forwarded to the
                nested pipeline so the LLM call is traced under this run.

        Returns:
            Dictionary with ``llm_output`` (raw model text), ``parsed_output``
            (result of the prompt's ``parse``), ``response_metadata`` and
            ``vision_prompt`` (the prompt text actually used).

        Raises:
            ValueError: If the chain was constructed without an ``llm``.
        """
        chain, current_prompt = self.setup_chain(inputs)
        config = (
            {"callbacks": run_manager.get_child()} if run_manager is not None else None
        )
        out = chain.invoke(
            {"prompt": current_prompt, "image_base64": inputs["image_encoded"]},
            config=config,
        )
        return self._build_result(out, current_prompt)  # type: ignore

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call``; same inputs, outputs and errors."""
        chain, current_prompt = self.setup_chain(inputs)
        config = (
            {"callbacks": run_manager.get_child()} if run_manager is not None else None
        )
        out = await chain.ainvoke(
            {"prompt": current_prompt, "image_base64": inputs["image_encoded"]},
            config=config,
        )
        return self._build_result(out, current_prompt)  # type: ignore
|