# PaperShow / Paper2Poster / docling / models / picture_description_api_model.py
# ZaynZhu
# Clean version without large assets
# 7c08dc3
import base64
import io
import logging
from typing import Iterable, List, Optional
import requests
from PIL import Image
from pydantic import BaseModel, ConfigDict
from docling.datamodel.pipeline_options import PictureDescriptionApiOptions
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
_log = logging.getLogger(__name__)
class ChatMessage(BaseModel):
    """A chat message as it appears inside one entry of the API's ``choices``.

    Only ``content`` is consumed by this module: the first choice's message
    content becomes the picture description.
    """

    role: str  # author role string exactly as reported by the API
    content: str  # message text; required and non-null for validation to pass
class ResponseChoice(BaseModel):
    """One entry of the ``choices`` list in a chat-completions response."""

    index: int
    message: ChatMessage
    # Optional with a default: some OpenAI-compatible servers send ``null`` or
    # omit this field. It is not read anywhere in this module, so its absence
    # must not make the whole response fail validation.
    finish_reason: Optional[str] = None
class ResponseUsage(BaseModel):
    """The ``usage`` object of an API response (token counts, per field names).

    All three fields are required integers; a response lacking any of them
    fails validation of the enclosing ``ApiResponse``.
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class ApiResponse(BaseModel):
    """Top-level JSON body returned by the chat-completions endpoint.

    Parsed with ``model_validate_json``; the description text is taken from
    ``choices[0].message.content``.
    """

    # Turn off pydantic's protected-namespace checks on field names.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
    """Describe pictures by calling a chat-completions style HTTP API.

    Each image is PNG-encoded, embedded as a base64 ``data:`` URL next to the
    configured text prompt, and POSTed to ``options.url`` together with any
    extra ``options.params``. The stripped content of the first returned
    choice is used as the image's description.

    Only loopback endpoints are accepted for now: when the model is enabled,
    any other host raises ``NotImplementedError`` at construction time.
    """

    # Hosts considered local. IPv6 literals may be reported with or without
    # brackets depending on the URL implementation, so both forms are listed.
    _LOCAL_HOSTS = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"})

    def __init__(self, enabled: bool, options: PictureDescriptionApiOptions):
        super().__init__(enabled=enabled, options=options)
        # Narrow the type of self.options for type checkers.
        self.options: PictureDescriptionApiOptions

        if self.enabled:
            # Previously only the literal "localhost" passed, which wrongly
            # rejected loopback IP addresses as "remote".
            if options.url.host not in self._LOCAL_HOSTS:
                raise NotImplementedError(
                    "The options try to connect to remote APIs which are not yet allowed."
                )

    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
        """Yield one description per image, in input order.

        Technically a single batch request would be possible, but not all APIs
        allow it (for example, vllm won't accept more than one image per
        request), so one request is sent per image.

        Raises:
            requests.HTTPError: if the API answers with an error status.
            pydantic.ValidationError: if the body does not match ApiResponse.
            IndexError: if the response contains no choices.
        """
        for image in images:
            yield self._describe_image(self._encode_png_base64(image))

    @staticmethod
    def _encode_png_base64(image: Image.Image) -> str:
        """Return the PNG encoding of *image* as a base64 ASCII string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def _describe_image(self, image_base64: str) -> str:
        """Send one request for a base64 PNG and return the stripped reply text."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": self.options.prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"
                        },
                    },
                ],
            }
        ]
        # User-supplied params are merged last, so they may override "messages".
        payload = {
            "messages": messages,
            **self.options.params,
        }

        response = requests.post(
            str(self.options.url),
            headers=self.options.headers,
            json=payload,
            timeout=self.options.timeout,
        )
        if not response.ok:
            _log.error("Error calling the API. Response was %s", response.text)
            response.raise_for_status()

        api_resp = ApiResponse.model_validate_json(response.text)
        return api_resp.choices[0].message.content.strip()