Vik Paruchuri committed
Commit ed75e8a · 1 Parent(s): 3e15f78

Refactor llm services

marker/processors/llm/llm_image_description.py CHANGED
@@ -9,15 +9,15 @@ from typing import Annotated, List
 
 
 class LLMImageDescriptionProcessor(BaseLLMSimpleBlockProcessor):
-    block_types = (BlockTypes.Picture, BlockTypes.Figure,)
-    extract_images: Annotated[
-        bool,
-        "Extract images from the document."
-    ] = True
+    block_types = (
+        BlockTypes.Picture,
+        BlockTypes.Figure,
+    )
+    extract_images: Annotated[bool, "Extract images from the document."] = True
     image_description_prompt: Annotated[
         str,
         "The prompt to use for generating image descriptions.",
-        "Default is a string containing the Gemini prompt."
+        "Default is a string containing the Gemini prompt.",
     ] = """You are a document analysis expert who specializes in creating text descriptions for images.
 You will receive an image of a picture or figure. Your job will be to create a short description of the image.
 **Instructions:**
@@ -41,26 +41,34 @@ In this figure, a bar chart titled "Fruit Preference Survey" is showing the numb
 
     def inference_blocks(self, document: Document) -> List[BlockData]:
         blocks = super().inference_blocks(document)
+        if self.extract_images:
+            return []
         return blocks
 
     def block_prompts(self, document: Document) -> List[PromptData]:
         prompt_data = []
         for block_data in self.inference_blocks(document):
             block = block_data["block"]
-            prompt = self.image_description_prompt.replace("{raw_text}", block.raw_text(document))
+            prompt = self.image_description_prompt.replace(
+                "{raw_text}", block.raw_text(document)
+            )
             image = self.extract_image(document, block)
 
-            prompt_data.append({
-                "prompt": prompt,
-                "image": image,
-                "block": block,
-                "schema": ImageSchema,
-                "page": block_data["page"]
-            })
+            prompt_data.append(
+                {
+                    "prompt": prompt,
+                    "image": image,
+                    "block": block,
+                    "schema": ImageSchema,
+                    "page": block_data["page"],
+                }
+            )
 
         return prompt_data
 
-    def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
+    def rewrite_block(
+        self, response: dict, prompt_data: PromptData, document: Document
+    ):
         block = prompt_data["block"]
 
         if not response or "image_description" not in response:
@@ -74,5 +82,6 @@ In this figure, a bar chart titled "Fruit Preference Survey" is showing the numb
 
         block.description = image_description
 
+
 class ImageSchema(BaseModel):
     image_description: str
marker/services/__init__.py CHANGED
@@ -1,10 +1,12 @@
 from typing import Optional, List, Annotated
+from io import BytesIO
 
 import PIL
 from pydantic import BaseModel
 
 from marker.schema.blocks import Block
 from marker.util import assign_config, verify_config_keys
+import base64
 
 
 class BaseService:
@@ -14,6 +16,11 @@ class BaseService:
     ] = 2
     retry_wait_time: Annotated[int, "The wait time between retries."] = 3
 
+    def img_to_base64(self, img: PIL.Image.Image):
+        image_bytes = BytesIO()
+        img.save(image_bytes, format="WEBP")
+        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+
     def process_images(self, images: List[PIL.Image.Image]) -> list:
         raise NotImplementedError
 
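The hunk above centralizes WEBP base64 encoding in BaseService so the individual services can drop their per-service copies. A minimal standalone sketch of the same logic; the test image in the usage block is illustrative and not part of the commit:

import base64
from io import BytesIO

import PIL.Image


def img_to_base64(img: PIL.Image.Image) -> str:
    # Serialize the image to WEBP in memory, then base64-encode the raw bytes.
    image_bytes = BytesIO()
    img.save(image_bytes, format="WEBP")
    return base64.b64encode(image_bytes.getvalue()).decode("utf-8")


if __name__ == "__main__":
    # Illustrative usage: encode a solid 8x8 image and print a prefix.
    sample = PIL.Image.new("RGB", (8, 8), "white")
    print(img_to_base64(sample)[:40])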
marker/services/azure_openai.py CHANGED
@@ -1,8 +1,6 @@
-import base64
 import json
 import time
-from io import BytesIO
-from typing import Annotated, List, Union
+from typing import Annotated, List
 
 import PIL
 from marker.logger import get_logger
@@ -18,30 +16,17 @@ logger = get_logger()
 
 class AzureOpenAIService(BaseService):
     azure_endpoint: Annotated[
-        str,
-        "The Azure OpenAI endpoint URL. No trailing slash."
+        str, "The Azure OpenAI endpoint URL. No trailing slash."
     ] = None
     azure_api_key: Annotated[
-        str,
-        "The API key to use for the Azure OpenAI service."
-    ] = None
-    azure_api_version: Annotated[
-        str,
-        "The Azure OpenAI API version to use."
+        str, "The API key to use for the Azure OpenAI service."
     ] = None
+    azure_api_version: Annotated[str, "The Azure OpenAI API version to use."] = None
     deployment_name: Annotated[
-        str,
-        "The deployment name for the Azure OpenAI model."
+        str, "The deployment name for the Azure OpenAI model."
     ] = None
 
-    def image_to_base64(self, image: PIL.Image.Image):
-        image_bytes = BytesIO()
-        image.save(image_bytes, format="WEBP")
-        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
-
-    def prepare_images(
-        self, images: Union[Image.Image, List[Image.Image]]
-    ) -> List[dict]:
+    def process_images(self, images: List[PIL.Image.Image]) -> list:
         if isinstance(images, Image.Image):
             images = [images]
 
@@ -49,10 +34,8 @@ class AzureOpenAIService(BaseService):
             {
                 "type": "image_url",
                 "image_url": {
-                    "url": "data:image/webp;base64,{}".format(
-                        self.image_to_base64(img)
-                    ),
-                }
+                    "url": "data:image/webp;base64,{}".format(self.img_to_base64(img)),
+                },
             }
             for img in images
         ]
@@ -60,8 +43,8 @@ class AzureOpenAIService(BaseService):
     def __call__(
         self,
         prompt: str,
-        image: PIL.Image.Image | List[PIL.Image.Image],
-        block: Block,
+        image: PIL.Image.Image | List[PIL.Image.Image] | None,
+        block: Block | None,
         response_schema: type[BaseModel],
         max_retries: int | None = None,
         timeout: int | None = None,
@@ -72,11 +55,8 @@ class AzureOpenAIService(BaseService):
         if timeout is None:
             timeout = self.timeout
 
-        if not isinstance(image, list):
-            image = [image]
-
         client = self.get_client()
-        image_data = self.prepare_images(image)
+        image_data = self.format_image_for_llm(image)
 
         messages = [
             {
@@ -94,7 +74,7 @@ class AzureOpenAIService(BaseService):
         response = client.beta.chat.completions.parse(
             extra_headers={
                 "X-Title": "Marker",
-                "HTTP-Referer": "https://github.com/VikParuchuri/marker",
+                "HTTP-Referer": "https://github.com/datalab-to/marker",
             },
             model=self.deployment_name,
             messages=messages,
@@ -124,4 +104,4 @@ class AzureOpenAIService(BaseService):
             api_version=self.azure_api_version,
             azure_endpoint=self.azure_endpoint,
             api_key=self.azure_api_key,
-        )
+        )
marker/services/claude.py CHANGED
@@ -1,7 +1,5 @@
-import base64
 import json
 import time
-from io import BytesIO
 from typing import List, Annotated, T
 
 import PIL
@@ -26,11 +24,6 @@ class ClaudeService(BaseService):
         int, "The maximum number of tokens to use for a single Claude request."
     ] = 8192
 
-    def img_to_base64(self, img: PIL.Image.Image):
-        image_bytes = BytesIO()
-        img.save(image_bytes, format="WEBP")
-        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
-
     def process_images(self, images: List[Image.Image]) -> List[dict]:
         return [
             {
marker/services/ollama.py CHANGED
@@ -1,6 +1,4 @@
-import base64
 import json
-from io import BytesIO
 from typing import Annotated, List
 
 import PIL
@@ -22,13 +20,8 @@ class OllamaService(BaseService):
         "llama3.2-vision"
     )
 
-    def image_to_base64(self, image: PIL.Image.Image):
-        image_bytes = BytesIO()
-        image.save(image_bytes, format="PNG")
-        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
-
     def process_images(self, images):
-        image_bytes = [self.image_to_base64(img) for img in images]
+        image_bytes = [self.img_to_base64(img) for img in images]
        return image_bytes
 
     def __call__(
marker/services/openai.py CHANGED
@@ -1,7 +1,5 @@
-import base64
 import json
 import time
-from io import BytesIO
 from typing import Annotated, List
 
 import openai
@@ -32,21 +30,6 @@ class OpenAIService(BaseService):
         "The image format to use for the OpenAI-like service. Use 'png' for better compatability",
     ] = "webp"
 
-    def image_to_base64(self, image: PIL.Image.Image) -> str:
-        """
-        Convert PIL image to base64 string
-
-        Args:
-            image: Input PIL Image
-            format: Format to use for the image; use "png" for better compatability.
-
-        Returns:
-            Base-64 encoded image (in PNG format) to pass to LLM.
-        """
-        image_bytes = BytesIO()
-        image.save(image_bytes, format=self.openai_image_format)
-        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
-
     def process_images(self, images: List[Image.Image]) -> List[dict]:
         """
         Generate the base-64 encoded message to send to an
@@ -67,7 +50,7 @@ class OpenAIService(BaseService):
                 "type": "image_url",
                 "image_url": {
                     "url": "data:image/{};base64,{}".format(
-                        self.openai_image_format, self.image_to_base64(img)
+                        self.openai_image_format, self.img_to_base64(img)
                     ),
                 },
             }
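As a usage sketch, the OpenAI-compatible services all emit image parts in the same data-URL shape shown in the hunks above. The helper below reproduces that shape outside the service classes; the function name and default format are illustrative, not from the commit:

import base64
from io import BytesIO

import PIL.Image


def image_content_part(img: PIL.Image.Image, fmt: str = "webp") -> dict:
    # Mirror of process_images(): base64-encode the image and wrap it in an
    # OpenAI-style "image_url" content part carrying a data URL.
    buf = BytesIO()
    img.save(buf, format=fmt.upper())
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    return {
        "type": "image_url",
        "image_url": {"url": "data:image/{};base64,{}".format(fmt, encoded)},
    }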