| from pathlib import Path |
| from string import ascii_letters, digits, punctuation |
|
|
| import numpy as np |
| import torch |
| from einops import rearrange |
| from jaxtyping import Float |
| from PIL import Image, ImageDraw, ImageFont |
| from torch import Tensor |
|
|
| from .layout import vcat |
|
|
# Reference character set (digits, punctuation and ASCII letters). draw_label
# measures the bounding box of this string, rather than of the label text
# itself, so that the label height does not depend on which characters a
# particular label happens to contain.
EXPECTED_CHARACTERS = digits + punctuation + ascii_letters
|
|
|
|
def draw_label(
    text: str,
    font: Path,
    font_size: int,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Draw a black label on a white background with no border.

    Args:
        text: The string to render.
        font: Path to the font file handed to ``ImageFont.truetype``. If the
            file cannot be loaded (``OSError``), Pillow's built-in default
            font is used instead.
        font_size: Size passed to ``ImageFont.truetype``. It is not applied
            to the fallback font.
        device: Device on which the returned tensor is created.

    Returns:
        A float32, channels-first tensor with values in [0, 1]. The width is
        the bounding-box width of ``text``; the height is the bounding-box
        height of ``EXPECTED_CHARACTERS``, so it is the same for every label
        rendered with a given font.
    """
    try:
        loaded_font = ImageFont.truetype(str(font), font_size)
    except OSError:
        # Deliberate best-effort fallback: a missing font file should not
        # prevent a label from being rendered.
        loaded_font = ImageFont.load_default()

    # Horizontal extent comes from the actual text; vertical extent comes from
    # the reference character set so all labels share one height.
    left, _, right, _ = loaded_font.getbbox(text)
    _, top, _, bottom = loaded_font.getbbox(EXPECTED_CHARACTERS)
    width = right - left
    height = bottom - top

    canvas = Image.new("RGB", (width, height), color="white")
    # The bounding boxes are measured relative to the drawing anchor, so the
    # ink starts at (left, top) rather than (0, 0). Shift the anchor by the
    # same amount so the measured box coincides with the canvas; drawing at
    # (0, 0) would push the text down/right and clip it at the canvas edge.
    ImageDraw.Draw(canvas).text((-left, -top), text, font=loaded_font, fill="black")

    # uint8 RGB in [0, 255] -> float32 in [0, 1], then HWC -> CHW.
    pixels = torch.tensor(np.array(canvas) / 255, dtype=torch.float32, device=device)
    return rearrange(pixels, "h w c -> c h w")
|
|
|
|
def add_label(
    image: Float[Tensor, "3 height width"],
    label: str,
    font: Path = Path("assets/Inter-Regular.otf"),
    font_size: int = 24,
) -> Float[Tensor, "3 height_with_label width_with_label"]:
    """Attach a rendered text label to an image.

    The label is rendered with ``draw_label`` on the same device as ``image``
    and combined with the image through ``vcat``, label first, using left
    alignment and a gap of 4.

    Args:
        image: Channels-first RGB image tensor.
        label: Text to render.
        font: Path to the font file used for the label.
        font_size: Font size used for the label.

    Returns:
        The result of ``vcat`` on the label and the image.
        NOTE(review): presumably the label ends up stacked above the image
        with a 4-pixel gap; confirm against ``layout.vcat``.
    """
    rendered_label = draw_label(label, font, font_size, image.device)
    return vcat(rendered_label, image, align="left", gap=4)
|
|