File size: 4,722 Bytes
e1ced8e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from __future__ import annotations
import base64
import io
import os
import shutil
import uuid
from pathlib import Path
from PIL import Image
from state import ImageRef
class ImageStore:
    """Disk-based image manager.

    LangGraph state only carries lightweight ``ImageRef`` dicts; all heavy
    image bytes live on disk under ``base_dir`` in three sub-directories:
    ``pages/``, ``crops/`` and ``annotated/``. Every image is stored as PNG.
    """

    def __init__(self, base_dir: str):
        """Create ``base_dir`` and its ``pages``/``crops``/``annotated`` sub-directories.

        Args:
            base_dir: Root directory of the store. Missing parents are created;
                an already-existing directory is reused.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self._pages_dir = self.base_dir / "pages"
        self._crops_dir = self.base_dir / "crops"
        self._annotated_dir = self.base_dir / "annotated"
        for d in (self._pages_dir, self._crops_dir, self._annotated_dir):
            d.mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Save helpers
    # ------------------------------------------------------------------

    def save_page_image(self, page_num: int, image_bytes: bytes) -> ImageRef:
        """Decode ``image_bytes`` and store it as ``pages/page_<page_num>.png``.

        Args:
            page_num: Page number used in the file name, id and label.
            image_bytes: Encoded image data in any format Pillow can open.

        Returns:
            An ``ImageRef`` with ``crop_type="full_page"`` describing the
            saved file.
        """
        path = self._pages_dir / f"page_{page_num}.png"
        # Close the decoded image once it is written so decoder resources are
        # released promptly; grab the dimensions while it is still open.
        with Image.open(io.BytesIO(image_bytes)) as img:
            img.save(str(path), format="PNG")
            width, height = img.size
        return ImageRef(
            id=f"page_{page_num}",
            path=str(path),
            label=f"Page {page_num} (full page)",
            page_num=page_num,
            crop_type="full_page",
            width=width,
            height=height,
        )

    def save_crop(
        self,
        page_num: int,
        crop_id: str,
        image: Image.Image,
        label: str,
    ) -> ImageRef:
        """Store ``image`` as ``crops/page_<page_num>_<crop_id>.png``.

        Args:
            page_num: Page the crop was taken from.
            crop_id: Identifier of the crop; combined with ``page_num`` to
                form both the file name and the ref id.
            image: The cropped image to write.
            label: Human-readable label stored on the returned ref.

        Returns:
            An ``ImageRef`` with ``crop_type="crop"`` describing the saved file.
        """
        ref_id = f"page_{page_num}_{crop_id}"
        path = self._crops_dir / f"{ref_id}.png"
        image.save(str(path), format="PNG")
        return ImageRef(
            id=ref_id,
            path=str(path),
            label=label,
            page_num=page_num,
            crop_type="crop",
            width=image.width,
            height=image.height,
        )

    def save_annotated(
        self,
        source_ref: ImageRef,
        annotated_image: Image.Image,
    ) -> ImageRef:
        """Store an annotated version of ``source_ref`` as ``annotated/<id>_ann.png``.

        Args:
            source_ref: Ref of the image that was annotated; its ``id``,
                ``label`` and ``page_num`` are used to derive the new ref.
            annotated_image: The annotated image to write.

        Returns:
            An ``ImageRef`` with ``crop_type="annotated"`` whose id is the
            source id suffixed with ``_ann``.
        """
        ann_id = f"{source_ref['id']}_ann"
        path = self._annotated_dir / f"{ann_id}.png"
        annotated_image.save(str(path), format="PNG")
        return ImageRef(
            id=ann_id,
            path=str(path),
            label=f"{source_ref['label']} [annotated]",
            page_num=source_ref["page_num"],
            crop_type="annotated",
            width=annotated_image.width,
            height=annotated_image.height,
        )

    # ------------------------------------------------------------------
    # Load helpers
    # ------------------------------------------------------------------

    def load_image(self, ref: ImageRef) -> Image.Image:
        """Open the image at ``ref["path"]`` with Pillow.

        NOTE: ``Image.open`` is lazy, so the returned image may keep the
        underlying file open until its pixel data is loaded or the image is
        closed. Callers should close it (or use it as a context manager)
        when finished.
        """
        return Image.open(ref["path"])

    def load_bytes(self, ref: ImageRef) -> bytes:
        """Return the raw file contents at ``ref["path"]``."""
        return Path(ref["path"]).read_bytes()

    def get_page_image_path(self, page_num: int) -> str:
        """Return the path where the full-page image for ``page_num`` is stored.

        The file is not required to exist.
        """
        return str(self._pages_dir / f"page_{page_num}.png")

    def load_page_bytes(self, page_num: int) -> bytes:
        """Return the raw PNG file contents of the full-page image for ``page_num``.

        Raises:
            FileNotFoundError: If no page image has been saved for ``page_num``.
        """
        return Path(self.get_page_image_path(page_num)).read_bytes()

    # ------------------------------------------------------------------
    # Format conversions for different model APIs
    # ------------------------------------------------------------------

    def to_gemini_part(self, ref: ImageRef):
        """Return a ``google.genai.types.Part`` for Gemini multimodal prompts."""
        # Imported lazily so the store works without google-genai installed
        # unless this conversion is actually requested.
        from google.genai import types

        return types.Part.from_bytes(
            data=self.load_bytes(ref),
            mime_type="image/png",
        )

    def to_openai_base64(self, ref: ImageRef) -> dict:
        """Return an OpenAI-compatible image content block (base64 data URI)."""
        b64 = base64.b64encode(self.load_bytes(ref)).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{b64}"},
        }

    def create_thumbnail(self, ref: ImageRef, max_size: int = 400) -> bytes:
        """Return PNG bytes of ``ref`` downscaled to fit a ``max_size`` square.

        Args:
            ref: Ref of the image to shrink; the file on disk is not modified.
            max_size: Maximum width and height, in pixels, of the thumbnail.

        Returns:
            The PNG-encoded thumbnail.
        """
        buf = io.BytesIO()
        # The context manager closes the source file handle, which the
        # previous implementation left open after every call.
        with self.load_image(ref) as img:
            img.thumbnail((max_size, max_size))
            img.save(buf, format="PNG")
        return buf.getvalue()

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def cleanup(self):
        """Delete ``base_dir`` and everything beneath it.

        Best-effort: errors (including the directory already being gone) are
        deliberately ignored, so calling this more than once is harmless.
        """
        shutil.rmtree(self.base_dir, ignore_errors=True)
|