Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import HfHubHTTPError | |
| from PIL import Image | |
@dataclass
class LabelData:
    """Tag names plus, per tag category, the row indices of the tags in it.

    Instances are built by ``load_labels_hf`` from ``selected_tags.csv``;
    every index list refers to positions in ``names``.
    """

    # Tag names, in CSV row order.
    names: list[str]
    # Row indices whose ``category`` column equals 9.
    rating: list[np.int64]
    # Row indices whose ``category`` column equals 0.
    general: list[np.int64]
    # Row indices whose ``category`` column equals 4.
    character: list[np.int64]
@dataclass
class ImageLabels:
    """Labelling result for a single image.

    NOTE(review): the producer of this record is outside the visible part of
    the file; the per-field notes below are inferred from names and types and
    should be confirmed against it.
    """

    # Presumably the predicted tags joined into a caption string.
    caption: str
    # Presumably the same tags in booru-style formatting.
    booru: str
    # Label name -> score for each category (assumed; verify against caller).
    rating: dict[str, float]
    general: dict[str, float]
    character: dict[str, float]
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hugging Face repo and index its tags.

    Args:
        repo_id: Hub repository that hosts ``selected_tags.csv``.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional access token passed through to ``hf_hub_download``.

    Returns:
        A ``LabelData`` holding every tag name in CSV row order, plus the row
        indices whose ``category`` is 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: if the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    csv_path = Path(downloaded).resolve()

    # Only these two columns are used; skip parsing the rest of the file.
    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]

    def rows_in_category(code: int) -> list[np.int64]:
        # np.where on a boolean mask yields the positional row indices.
        return list(np.where(categories == code)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_in_category(9),
        general=rows_in_category(0),
        character=rows_in_category(4),
    )
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening transparency onto white.

    Images already in RGB mode are returned unchanged. Any other mode is
    converted; if the image carries transparency it goes through RGBA and is
    composited over an opaque white canvas so transparent areas become white.
    """
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ("RGB", "RGBA"):
        # Transparency can live in an explicit alpha band (e.g. LA, PA) or in
        # the palette/tRNS "transparency" info entry. Either way, go through
        # RGBA so the alpha is composited below rather than silently dropped.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")

    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre ``image`` on a new square RGB canvas.

    Args:
        image: Image to pad.
        fill: RGB colour of the padding; defaults to white.

    Returns:
        A new RGB image whose side equals the larger of the input's width and
        height. A new canvas is created even when the input is already square.
    """
    w, h = image.size
    # the largest dimension becomes the side of the square
    px = max(w, h)
    canvas = Image.new("RGB", (px, px), fill)
    # integer-halve the slack on each axis so the image sits centred
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Args:
        image: Source image in any mode.
        size_px: Target size; a single int means a square of that side.
        upscale: If False, an image smaller than the target raises instead of
            being enlarged.

    Returns:
        The RGB, square-padded image, resized towards ``size_px``.

    Raises:
        ValueError: if the padded image is smaller than ``size_px`` in either
            dimension and ``upscale`` is False.
    """
    if isinstance(size_px, int):
        size_px = (size_px, size_px)

    # ensure RGB and pad to square
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)

    width, height = image.size
    if width < size_px[0] or height < size_px[1]:
        # too small in at least one dimension: enlarge to exactly the target size
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(size_px, Image.LANCZOS)
    elif width > size_px[0] or height > size_px[1]:
        # thumbnail() shrinks in place and keeps the aspect ratio; this is safe
        # because pil_pad_square returned a fresh canvas, not the caller's image
        image.thumbnail(size_px, Image.BICUBIC)

    return image
# Tag names that are kaomoji (emoticon) strings.
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]