Spaces:
Runtime error
Runtime error
| import math | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import HfHubHTTPError | |
| from PIL import Image | |
| from torch import Tensor, nn | |
@dataclass
class Heatmap:
    """Container pairing a label and its score with an image.

    Declared as a dataclass so it can be built with keyword arguments
    (``Heatmap(label=..., score=..., image=...)``); without the decorator the
    annotated fields are never turned into constructor parameters.
    NOTE(review): presumably a per-tag visualization image — confirm with callers.
    """

    label: str  # tag / label name the image belongs to
    score: float  # score associated with the label
    image: Image.Image  # the rendered image
@dataclass
class LabelData:
    """Tag vocabulary read from ``selected_tags.csv``.

    ``names`` holds every tag name in CSV row order; ``rating``, ``general`` and
    ``character`` hold row positions (indices into ``names``) of the tags in the
    corresponding category. Declared as a dataclass because it is constructed
    with keyword arguments (e.g. by ``load_labels_hf``).
    """

    names: list[str]  # all tag names, in CSV row order
    rating: list[np.int64]  # row indices of rating tags
    general: list[np.int64]  # row indices of general tags
    character: list[np.int64]  # row indices of character tags
@dataclass
class ImageLabels:
    """Tagging result for a single image.

    Declared as a dataclass so it can be built with keyword arguments; without
    the decorator the annotated fields are never turned into constructor
    parameters.
    NOTE(review): the score dicts presumably map tag name -> probability —
    confirm against the code that fills them.
    """

    caption: str  # caption-style tag string
    booru: str  # booru-style tag string
    rating: dict[str, float]  # rating tag scores
    general: dict[str, float]  # general tag scores
    character: dict[str, float]  # character tag scores
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hugging Face Hub repo and index its tags.

    Args:
        repo_id: Hub repository to download from.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        A ``LabelData`` whose ``names`` come from the CSV ``name`` column and
        whose ``rating`` / ``general`` / ``character`` lists hold the row
        positions of tags with ``category`` 9 / 0 / 4 respectively.

    Raises:
        FileNotFoundError: if the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err

    frame: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = frame["category"]

    def rows_in(category: int) -> list[np.int64]:
        # positional row indices of every tag in the given category
        return list(np.where(categories == category)[0])

    return LabelData(
        names=frame["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sort the scores in descending order, find the largest drop between two
    consecutive scores, and return the midpoint of that drop.

    Args:
        probs: Scores to threshold. Any array-like is accepted; it is flattened
            to 1-D first.

    Returns:
        The threshold as a Python ``float``.

    Raises:
        ValueError: if fewer than two scores are given (there is no gap to cut).
    """
    ranked = np.sort(np.asarray(probs).ravel())[::-1]
    if ranked.size < 2:
        # Without this check numpy fails with an opaque "argmax of an empty
        # sequence" error; same exception type, clearer message.
        raise ValueError(f"mcut_threshold needs at least two scores, got {ranked.size}")
    gaps = ranked[:-1] - ranked[1:]  # non-negative drops between neighbours
    idx = int(gaps.argmax())
    return float((ranked[idx] + ranked[idx + 1]) / 2)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any transparency onto white.

    Args:
        image: Source image in any PIL mode.

    Returns:
        An ``RGB`` image. Already-RGB input is returned as-is.
    """
    # Convert to RGB/RGBA if not already (deals with palette images etc.).
    # Go through RGBA whenever the image carries transparency: either a real
    # alpha band (e.g. "LA", "PA") or a "transparency" info key (e.g. palette
    # images). Converting those straight to RGB would drop the alpha and expose
    # whatever colour lies under transparent pixels instead of white.
    if image.mode not in ("RGB", "RGBA"):
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    # Convert RGBA to RGB by compositing over an opaque white background.
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre *image* on a square RGB canvas whose side is its longer edge.

    Args:
        image: Source image; it is pasted without scaling.
        fill: RGB colour of the padding (white by default).

    Returns:
        A new ``RGB`` image of size ``(max(w, h), max(w, h))``.
    """
    width, height = image.size
    side = max(width, height)
    # Floor division: an odd leftover pixel of padding ends up on the right/bottom.
    offset = ((side - width) // 2, (side - height) // 2)
    padded = Image.new("RGB", (side, side), fill)
    padded.paste(image, offset)
    return padded
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Processing steps:
    1. Convert the image to RGB, flattening transparency onto white.
    2. Pad it to a square.
    3. Scale it towards ``size_px``. A smaller image is resized up with LANCZOS.
       A larger one is shrunk in place with ``thumbnail`` (BICUBIC).

    NOTE(review): for a non-square ``size_px`` the result may not match it
    exactly. Upscaling stretches the square to ``size_px``, while ``thumbnail``
    only fits the square inside it.

    Args:
        image: Source image.
        size_px: Target side length, or a ``(width, height)`` tuple.
        upscale: When exactly ``False``, raise instead of enlarging an image
            smaller than the target.

    Returns:
        The processed image.

    Raises:
        ValueError: if the padded image is smaller than ``size_px`` and
            ``upscale`` is ``False``.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    square = pil_pad_square(pil_ensure_rgb(image))

    width, height = square.size
    if width < target[0] or height < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        square = square.resize(target, Image.LANCZOS)

    # Re-read the size: the resize above may have changed it.
    width, height = square.size
    if width > target[0] or height > target[1]:
        square.thumbnail(target, Image.BICUBIC)
    return square
def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
    partial_rows: bool = True,
) -> Image.Image:
    """Tile images into a padded grid on a solid background.

    The grid has ``min(isqrt(len(images)), max_cols)`` columns and as many rows
    as needed. When ``partial_rows`` is False, an incomplete last row is dropped
    along with the images that would have filled it.

    Args:
        images: Images to tile. All are assumed to share the size of the first.
        max_cols: Upper bound on the number of columns (must be >= 1).
        padding: Gap in pixels around and between tiles.
        bg_color: RGB colour of the background and padding.
        partial_rows: Whether to keep an incomplete final row.

    Returns:
        A new ``RGB`` image containing the grid.

    Raises:
        ValueError: if ``images`` is empty or ``max_cols`` is less than 1
            (previously these surfaced as ``ZeroDivisionError``/``IndexError``).
    """
    if not images:
        raise ValueError("pil_make_grid needs at least one image")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    # isqrt is exact for all ints, unlike floor(sqrt(n)) on a float.
    n_cols = min(math.isqrt(len(images)), max_cols)
    n_rows = math.ceil(len(images) / n_cols)
    # If the final row is not full and partial_rows is False, remove a row.
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1
    # Only draw the images that have a cell; with partial_rows=False the images
    # of the dropped row would otherwise be pasted off-canvas.
    n_cells = n_cols * n_rows

    # Assumes all images are the same size.
    image_width, image_height = images[0].size
    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding

    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
    for i, img in enumerate(images[:n_cells]):
        row, col = divmod(i, n_cols)
        x = col * (image_width + padding) + padding
        y = row * (image_height + padding) + padding
        canvas.paste(img, (x, y))
    return canvas
# Kaomoji-style tags, copied from toriato's stable-diffusion-webui-wd14-tagger:
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# NOTE(review): presumably used to keep the underscores in these tags when other
# tags have "_" replaced with " " — confirm against the callers.
# None of the entries contains whitespace, so a plain split() yields them in order.
kaomojis = """
0_0 (o)_(o) +_+ +_- ._. <o>_<o> <|>_<|> =_= >_< 3_3
6_9 >_o @_@ ^_^ o_o u_u x_x |_| ||_||
""".split()