Ryan2219's picture
Upload 70 files
e1ced8e verified
from __future__ import annotations
import base64
import io
import os
import shutil
import uuid
from pathlib import Path
from PIL import Image
from state import ImageRef
class ImageStore:
    """Disk-based image manager.

    LangGraph state only carries lightweight ``ImageRef`` dicts; all heavy
    image bytes live on disk under ``base_dir``, split into three
    sub-directories:

    * ``pages/``     - full-page images (``page_<n>.png``)
    * ``crops/``     - crops of a page (``page_<n>_<crop_id>.png``)
    * ``annotated/`` - annotated copies of another ref (``<id>_ann.png``)

    Every file the store writes is PNG.
    """

    # Image modes Pillow's PNG encoder can write as-is. Anything else (e.g.
    # CMYK, which JPEG sources often use) must be converted first, otherwise
    # ``Image.save(..., format="PNG")`` raises.
    _PNG_MODES = frozenset(
        {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
    )

    def __init__(self, base_dir: str):
        """Create ``base_dir`` and its ``pages``/``crops``/``annotated`` sub-dirs.

        Args:
            base_dir: Root directory for this store. Missing parents are
                created; already-existing directories are reused.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self._pages_dir = self.base_dir / "pages"
        self._crops_dir = self.base_dir / "crops"
        self._annotated_dir = self.base_dir / "annotated"
        for d in (self._pages_dir, self._crops_dir, self._annotated_dir):
            d.mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Save helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _png_compatible(image: Image.Image) -> Image.Image:
        """Return *image* if PNG can store its mode, else an RGB/RGBA conversion.

        Modes ending in an alpha channel (e.g. ``"PA"``, ``"RGBa"``) are
        converted to ``"RGBA"`` so transparency survives; everything else
        becomes ``"RGB"``.
        """
        if image.mode in ImageStore._PNG_MODES:
            return image
        target = "RGBA" if image.mode.endswith(("A", "a")) else "RGB"
        return image.convert(target)

    def save_page_image(self, page_num: int, image_bytes: bytes) -> ImageRef:
        """Decode *image_bytes* and store it as ``pages/page_<page_num>.png``.

        Args:
            page_num: Page number; used in the file name, id and label.
            image_bytes: Encoded image data in any format Pillow can read.

        Returns:
            An ``ImageRef`` with ``crop_type="full_page"`` and the saved
            image's dimensions. An existing file for the same page is
            overwritten.
        """
        fname = f"page_{page_num}.png"
        path = self._pages_dir / fname
        # ``with`` releases the decoder's resources once the PNG is written.
        with Image.open(io.BytesIO(image_bytes)) as src:
            img = self._png_compatible(src)
            img.save(str(path), format="PNG")
            width, height = img.width, img.height
        return ImageRef(
            id=f"page_{page_num}",
            path=str(path),
            label=f"Page {page_num} (full page)",
            page_num=page_num,
            crop_type="full_page",
            width=width,
            height=height,
        )

    def save_crop(
        self,
        page_num: int,
        crop_id: str,
        image: Image.Image,
        label: str,
    ) -> ImageRef:
        """Store a crop of page *page_num* as ``crops/page_<n>_<crop_id>.png``.

        Args:
            page_num: Page the crop was taken from.
            crop_id: Suffix used in both the file name and the ref id.
                NOTE(review): interpolated into the path unmodified; if it can
                come from untrusted/model output, make sure callers strip path
                separators.
            image: The cropped image (converted to an RGB(A) mode first if
                PNG cannot store its mode).
            label: Human-readable label stored on the ref.

        Returns:
            An ``ImageRef`` with ``crop_type="crop"``.
        """
        fname = f"page_{page_num}_{crop_id}.png"
        path = self._crops_dir / fname
        img = self._png_compatible(image)
        img.save(str(path), format="PNG")
        return ImageRef(
            id=f"page_{page_num}_{crop_id}",
            path=str(path),
            label=label,
            page_num=page_num,
            crop_type="crop",
            width=img.width,
            height=img.height,
        )

    def save_annotated(
        self,
        source_ref: ImageRef,
        annotated_image: Image.Image,
    ) -> ImageRef:
        """Store an annotated version of *source_ref* under ``annotated/``.

        The new ref's id is ``<source id>_ann``, its label gets an
        ``[annotated]`` suffix, and ``page_num`` is copied from the source.
        Annotating the same source again overwrites the previous file.

        Returns:
            An ``ImageRef`` with ``crop_type="annotated"``.
        """
        ann_id = f"{source_ref['id']}_ann"
        fname = f"{ann_id}.png"
        path = self._annotated_dir / fname
        img = self._png_compatible(annotated_image)
        img.save(str(path), format="PNG")
        return ImageRef(
            id=ann_id,
            path=str(path),
            label=f"{source_ref['label']} [annotated]",
            page_num=source_ref["page_num"],
            crop_type="annotated",
            width=img.width,
            height=img.height,
        )

    # ------------------------------------------------------------------
    # Load helpers
    # ------------------------------------------------------------------
    def load_image(self, ref: ImageRef) -> Image.Image:
        """Open ``ref["path"]`` and return the image with its pixels loaded.

        ``Image.open`` is lazy and keeps the file open until the image is
        loaded or garbage-collected. Loading eagerly inside ``with`` closes
        the handle before returning, so callers cannot leak descriptors. The
        returned image also stays usable if the file is later removed (e.g.
        by ``cleanup``).
        """
        with Image.open(ref["path"]) as img:
            img.load()
        return img

    def load_bytes(self, ref: ImageRef) -> bytes:
        """Return the raw (encoded) file contents at ``ref["path"]``."""
        return Path(ref["path"]).read_bytes()

    def get_page_image_path(self, page_num: int) -> str:
        """Return the path where ``save_page_image`` stores page *page_num*.

        The file is not required to exist.
        """
        return str(self._pages_dir / f"page_{page_num}.png")

    def load_page_bytes(self, page_num: int) -> bytes:
        """Return the encoded PNG bytes of page *page_num*.

        Raises:
            FileNotFoundError: if that page was never saved.
        """
        return Path(self.get_page_image_path(page_num)).read_bytes()

    # ------------------------------------------------------------------
    # Format conversions for different model APIs
    # ------------------------------------------------------------------
    def to_gemini_part(self, ref: ImageRef):
        """Return a ``google.genai.types.Part`` for Gemini multimodal prompts."""
        # Imported lazily so the store works without the Gemini SDK installed.
        from google.genai import types

        return types.Part.from_bytes(
            data=self.load_bytes(ref),
            mime_type="image/png",
        )

    def to_openai_base64(self, ref: ImageRef) -> dict:
        """Return an OpenAI-compatible image content block (base64 data URI)."""
        b64 = base64.b64encode(self.load_bytes(ref)).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{b64}"},
        }

    def create_thumbnail(self, ref: ImageRef, max_size: int = 400) -> bytes:
        """Return PNG bytes of *ref* shrunk to fit in ``max_size`` x ``max_size``.

        Follows ``Image.thumbnail`` semantics: aspect ratio is preserved and
        smaller images are not enlarged. The file on disk is not modified.
        """
        img = self._png_compatible(self.load_image(ref))
        img.thumbnail((max_size, max_size))
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return buf.getvalue()

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------
    def cleanup(self):
        """Delete ``base_dir`` and everything under it, ignoring errors.

        The sub-directories are only created in ``__init__``, so the save
        helpers will fail on this instance after cleanup.
        """
        if self.base_dir.exists():
            shutil.rmtree(self.base_dir, ignore_errors=True)