MauritsBrinkman committed
Commit · 84cbb6b
1 Parent(s): ce48a51
Add Azure OpenAI service support to marker package
- Create AzureOpenAIService class that implements the BaseService interface
- Fix duplicate function name in test_service_init.py
- Add test case for the Azure OpenAI service
- Update README.md to document the Azure OpenAI service option
- Add a sample script for converting PDFs with Azure OpenAI

This implementation allows marker to use Azure OpenAI for LLM-enhanced
processing and image descriptions by configuring the azure_endpoint,
azure_api_key, and deployment_name parameters.
- README.md +1 -0
- marker/services/azure_openai.py +162 -0
- tests/services/test_service_init.py +9 -2
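
The commit message mentions a sample script for converting PDFs with Azure OpenAI. As a minimal sketch of what such a script could look like: the entry points below (`PdfConverter`, `create_model_dict`, `ConfigParser`) follow marker's documented Python API but are assumptions here, since the script itself is not shown in this diff, and the endpoint, key, and deployment values are placeholders.

# Hypothetical sample script: convert a PDF with the new Azure OpenAI service.
from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict

config = {
    "use_llm": True,
    "llm_service": "marker.services.azure_openai.AzureOpenAIService",
    "azure_endpoint": "https://example.openai.azure.com",  # placeholder endpoint
    "azure_api_key": "<your-api-key>",                     # placeholder key
    "deployment_name": "my-deployment",                    # placeholder deployment
}
config_parser = ConfigParser(config)

converter = PdfConverter(
    config=config_parser.generate_config_dict(),
    artifact_dict=create_model_dict(),
    llm_service=config_parser.get_llm_service(),
)
rendered = converter("input.pdf")  # rendered output (markdown by default)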
README.md
CHANGED
@@ -329,6 +329,7 @@ When running with the `--use_llm` flag, you have a choice of services you can use:
 - `Ollama` - this will use local models. You can configure `--ollama_base_url` and `--ollama_model`. To use it, set `--llm_service=marker.services.ollama.OllamaService`.
 - `Claude` - this will use the anthropic API. You can configure `--claude_api_key`, and `--claude_model_name`. To use it, set `--llm_service=marker.services.claude.ClaudeService`.
 - `OpenAI` - this supports any openai-like endpoint. You can configure `--openai_api_key`, `--openai_model`, and `--openai_base_url`. To use it, set `--llm_service=marker.services.openai.OpenAIService`.
+- `Azure OpenAI` - this uses the Azure OpenAI service. You can configure `--azure_endpoint`, `--azure_api_key`, and `--deployment_name`. To use it, set `--llm_service=marker.services.azure_openai.AzureOpenAIService`.
 
 These services may have additional optional configuration as well - you can see it by viewing the classes.
 
marker/services/azure_openai.py
ADDED
@@ -0,0 +1,162 @@
+import base64
+import json
+import time
+from io import BytesIO
+from typing import Annotated, List, Union
+
+from langchain_openai import AzureChatOpenAI
+from openai import AzureOpenAI, APITimeoutError, RateLimitError
+import PIL
+from PIL import Image
+from pydantic import BaseModel, Field
+from langchain_core.messages import HumanMessage
+
+from marker.schema.blocks import Block
+from marker.services import BaseService
+
+
+class ImageDescription(BaseModel):
+    """Model for image description response."""
+    image_description: str = Field(description="Detailed description of the image content. This will be formatted as markdown italic text.")
+
+
+class AzureOpenAIService(BaseService):
+    azure_endpoint: Annotated[
+        str,
+        "The Azure OpenAI endpoint URL. No trailing slash."
+    ] = None
+    azure_api_key: Annotated[
+        str,
+        "The API key to use for the Azure OpenAI service."
+    ] = None
+    azure_api_version: Annotated[
+        str,
+        "The Azure OpenAI API version to use."
+    ] = "2024-02-01"
+    deployment_name: Annotated[
+        str,
+        "The deployment name for the Azure OpenAI model."
+    ] = None
+
+    def image_to_base64(self, image: PIL.Image.Image):
+        image_bytes = BytesIO()
+        image.save(image_bytes, format="WEBP")
+        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+
+    def prepare_images(
+        self, images: Union[Image.Image, List[Image.Image]]
+    ) -> List[dict]:
+        if isinstance(images, Image.Image):
+            images = [images]
+
+        return [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": "data:image/webp;base64,{}".format(
+                        self.image_to_base64(img)
+                    ),
+                }
+            }
+            for img in images
+        ]
+
+    def __call__(
+        self,
+        prompt: str,
+        image: PIL.Image.Image | List[PIL.Image.Image],
+        block: Block,
+        response_schema: type[BaseModel],
+        max_retries: int | None = None,
+        timeout: int | None = None,
+    ):
+        if max_retries is None:
+            max_retries = self.max_retries
+
+        if timeout is None:
+            timeout = self.timeout
+
+        if not isinstance(image, list):
+            image = [image]
+
+        # Set up AzureChatOpenAI client
+        llm = AzureChatOpenAI(
+            azure_endpoint=self.azure_endpoint,
+            azure_deployment=self.deployment_name,
+            api_key=self.azure_api_key,
+            api_version=self.azure_api_version,
+            temperature=0.0,
+            max_tokens=800,
+            request_timeout=timeout,
+        )
+
+        # Create a structured output wrapper
+        structured_llm = llm.with_structured_output(
+            ImageDescription,
+            include_raw=False
+        )
+
+        # Create message content with the correct format for LangChain
+        # LangChain expects messages with 'role' and 'content'
+        # For multimodal content, we need to use a specific format
+        message_content = []
+        for img in image:
+            message_content.append({
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/webp;base64,{self.image_to_base64(img)}"
+                }
+            })
+
+        # Add the text prompt at the end
+        message_content.append({"type": "text", "text": prompt})
+
+        # Create a proper LangChain message
+        message = HumanMessage(content=message_content)
+
+        tries = 0
+        while tries < max_retries:
+            try:
+                # Use the structured output LLM to get a response
+                response = structured_llm.invoke([message])
+
+                # If successful, return the structured output directly
+                block.update_metadata(llm_tokens_used=800, llm_request_count=1)  # Approximate token usage
+
+                # Convert Pydantic model to dict
+                result = response.model_dump()
+
+                # Ensure compatibility with expected output format
+                if hasattr(response_schema, '__annotations__'):
+                    # Add missing fields from response_schema if needed
+                    for key in response_schema.__annotations__:
+                        if key not in result and key != 'image_description':
+                            result[key] = ""
+
+                return result
+
+            except (APITimeoutError, RateLimitError) as e:
+                # Rate limit exceeded
+                tries += 1
+                wait_time = tries * 3
+                print(
+                    f"Rate limit error: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{max_retries})"
+                )
+                time.sleep(wait_time)
+            except Exception as e:
+                print(f"Error: {str(e)}")
+                tries += 1
+                if tries < max_retries:
+                    wait_time = tries * 2
+                    time.sleep(wait_time)
+                else:
+                    break
+
+        return {}
+
+    def get_client(self) -> AzureOpenAI:
+        return AzureOpenAI(
+            api_version=self.azure_api_version,
+            azure_endpoint=self.azure_endpoint,
+            api_key=self.azure_api_key,
+        )
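
Two notes on the new file. First, `__call__` always parses the model response into `ImageDescription` and backfills any other keys declared on the caller's `response_schema` with empty strings, so richer schemas are only partially honored. Second, images are shipped as WEBP data URLs; the following self-contained round-trip check of that encoding is a sketch for illustration (standard library plus Pillow only), not part of the commit:

import base64
from io import BytesIO

from PIL import Image

# Encode an image the way image_to_base64 does, then decode it back.
img = Image.new("RGB", (8, 8), "white")
buf = BytesIO()
img.save(buf, format="WEBP")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

# Shape of the image_url parts sent to the model
url = f"data:image/webp;base64,{b64}"
assert url.startswith("data:image/webp;base64,")

# Decoding recovers the original image
restored = Image.open(BytesIO(base64.b64decode(b64)))
assert restored.size == (8, 8)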
tests/services/test_service_init.py
CHANGED
@@ -5,6 +5,7 @@ from marker.services.gemini import GoogleGeminiService
 from marker.services.ollama import OllamaService
 from marker.services.vertex import GoogleVertexService
 from marker.services.openai import OpenAIService
+from marker.services.azure_openai import AzureOpenAIService
 
 
 @pytest.mark.output_format("markdown")
@@ -43,6 +44,12 @@ def test_llm_ollama(pdf_converter: PdfConverter, temp_doc):
 
 @pytest.mark.output_format("markdown")
 @pytest.mark.config({"page_range": [0], "use_llm": True, "llm_service": "marker.services.openai.OpenAIService", "openai_api_key": "test"})
-def test_llm_ollama(pdf_converter: PdfConverter, temp_doc):
+def test_llm_openai(pdf_converter: PdfConverter, temp_doc):
     assert pdf_converter.artifact_dict["llm_service"] is not None
-    assert isinstance(pdf_converter.llm_service, OllamaService)
+    assert isinstance(pdf_converter.llm_service, OpenAIService)
+
+@pytest.mark.output_format("markdown")
+@pytest.mark.config({"page_range": [0], "use_llm": True, "llm_service": "marker.services.azure_openai.AzureOpenAIService", "azure_endpoint": "https://example.openai.azure.com", "azure_api_key": "test", "deployment_name": "test-model"})
+def test_llm_azure_openai(pdf_converter: PdfConverter, temp_doc):
+    assert pdf_converter.artifact_dict["llm_service"] is not None
+    assert isinstance(pdf_converter.llm_service, AzureOpenAIService)
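
The config keys in the new test marker map one-to-one onto the `Annotated` fields declared on `AzureOpenAIService`. A small sketch of that mapping, assuming (as the test fixture implies, though it is not shown in this diff) that marker services accept a plain config dict at construction:

from marker.services.azure_openai import AzureOpenAIService

# Same placeholder values as the @pytest.mark.config marker above; the
# config-dict constructor is an assumption about marker's BaseService.
config = {
    "azure_endpoint": "https://example.openai.azure.com",
    "azure_api_key": "test",
    "deployment_name": "test-model",
}
service = AzureOpenAIService(config)
assert service.deployment_name == "test-model"
assert service.azure_api_version == "2024-02-01"  # class default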
|