Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import argparse | |
| import logging | |
| from dataclasses import dataclass | |
| from os import PathLike | |
| from pathlib import Path | |
| from typing import Generator, Optional, Tuple | |
| import numpy as np | |
| import onnxruntime as rt | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import HfHubHTTPError | |
| from pandas import DataFrame, read_csv | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
# File suffixes (lower-case, including the leading dot) accepted as images.
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".webp",
    ".bmp",
    ".tiff",
    ".tif",
]

# Default edge length, in pixels, that images are preprocessed to.
IMAGE_SIZE = 448

# Short variant name -> Hugging Face repository id of the matching tagger.
MODEL_VARIANTS: dict[str, str] = dict(
    swinv2="SmilingWolf/wd-swinv2-tagger-v3",
    convnext="SmilingWolf/wd-convnext-tagger-v3",
    vit="SmilingWolf/wd-vit-tagger-v3",
)
@dataclass
class LabelData:
    """Tag names together with the row indices belonging to each tag category.

    The index lists hold positions into ``names``; the loader in this file
    fills them from the ``category`` column of ``selected_tags.csv``
    (9 = rating, 0 = general, 4 = character).
    """

    # Tag names, in the row order of the CSV file.
    names: list[str]
    # Row indices of the rating tags.
    rating: list[np.int64]
    # Row indices of the general tags.
    general: list[np.int64]
    # Row indices of the character tags.
    character: list[np.int64]
@dataclass
class ImageLabels:
    """Tagging result for a single image."""

    # Selected general tags followed by selected character tags, joined by ", ".
    caption: str
    # The caption with underscores turned into spaces and parentheses escaped.
    booru: str
    # Name of the rating tag with the highest score.
    rating: str
    # General tag name -> score, for tags above the general threshold.
    general: dict[str, float]
    # Character tag name -> score, for tags above the character threshold.
    character: dict[str, float]
    # Every rating tag name -> score.
    ratings: dict[str, float]
# Configure the root logger with timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.WARNING,
)
logger = logging.getLogger()
# basicConfig asked for WARNING; the root logger is then lowered to INFO.
logger.setLevel(logging.INFO)
| ## Model loading functions | |
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX file from a Hugging Face repo and return its local path.

    Args:
        repo_id: Repository to download from.
        filename: File to fetch; ``.onnx`` is appended when it is missing.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The resolved path of the downloaded file.
    """
    onnx_name = filename if filename.endswith(".onnx") else filename + ".onnx"
    local_file = hf_hub_download(
        repo_id=repo_id,
        filename=onnx_name,
        revision=revision,
        token=token,
    )
    return Path(local_file).resolve()
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download the model for ``repo_id`` and open an ONNX Runtime session on it.

    Args:
        repo_id: Hugging Face repository holding ``model.onnx``.
        revision: Optional revision forwarded to the download.
        token: Optional access token forwarded to the download.

    Returns:
        An inference session created with the CUDA provider listed first and
        the CPU provider second.

    Raises:
        FileNotFoundError: If no model file exists at the downloaded location.
    """
    onnx_file = download_onnx(repo_id, revision=revision, token=token)
    if not onnx_file.is_file():
        # The download may have resolved to a directory; look inside it.
        onnx_file = onnx_file / "model.onnx"
        if not onnx_file.is_file():
            raise FileNotFoundError(f"Model not found: {onnx_file}")

    providers = [("CUDAExecutionProvider", {}), "CPUExecutionProvider"]
    return rt.InferenceSession(str(onnx_file), providers=providers)
| ## Label loading function | |
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from ``repo_id`` and index its tags by category.

    Args:
        repo_id: Hugging Face repository holding ``selected_tags.csv``.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        A ``LabelData`` with all tag names and, per category, the row indices
        of the tags in that category (9 = rating, 0 = general, 4 = character).

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err

    tags: DataFrame = read_csv(csv_path, usecols=["name", "category"])
    category = tags["category"]

    def rows_with(code: int) -> list:
        # Row positions whose category column equals ``code``.
        return list(np.where(category == code)[0])

    return LabelData(
        names=tags["name"].tolist(),
        rating=rows_with(9),
        general=rows_with(0),
        character=rows_with(4),
    )
| ## Image preprocessing functions | |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any alpha onto a white background."""
    # Palette, greyscale and other modes are first brought to RGB, or to RGBA
    # when the image info carries a transparency entry.
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Composite the RGBA image over a white canvas, then drop the alpha channel.
    backdrop = Image.new("RGBA", image.size, (255, 255, 255))
    backdrop.alpha_composite(image)
    return backdrop.convert("RGB")
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre ``image`` on a new square RGB canvas filled with ``fill``.

    The canvas edge equals the larger of the image's width and height.
    """
    width, height = image.size
    side = max(width, height)
    # Top-left corner that centres the image on the canvas.
    offset = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """Convert an image to RGB, pad it to a white square and fit it to ``size_px``.

    Args:
        image: Source image in any mode.
        size_px: Target size; a single int is used for both dimensions.
        upscale: When ``False``, an image smaller than the target raises
            instead of being enlarged.

    Returns:
        The squared image, enlarged with LANCZOS or shrunk with BICUBIC.

    Raises:
        ValueError: If the image is smaller than the target and ``upscale``
            is ``False``.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    squared = pil_pad_square(pil_ensure_rgb(image))
    cur_w, cur_h = squared.size

    if cur_w < target[0] or cur_h < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        # resize() yields exactly the target size, so no shrink step follows.
        return squared.resize(target, Image.LANCZOS)

    if cur_w > target[0] or cur_h > target[1]:
        # thumbnail() shrinks the image in place.
        squared.thumbnail(target, Image.BICUBIC)
    return squared
| ## Dataset for DataLoader | |
class ImageDataset(Dataset):
    """Dataset that loads image files and preprocesses them for the tagger.

    Paths whose suffix is not in ``IMAGE_EXTENSIONS`` are dropped at
    construction time.
    """

    def __init__(self, image_paths: list[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        """Store the preprocessing options and keep only recognised image paths.

        Args:
            image_paths: Candidate file paths.
            size_px: Target size forwarded to ``preprocess_image``.
            upscale: Whether images smaller than the target may be enlarged.
        """
        self.size_px = size_px
        self.upscale = upscale
        self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        """Return the number of image paths kept."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Returns:
            A dict with ``"image"`` (float32 array of shape H x W x 3 in BGR
            channel order) and ``"path"`` (the file path as a UTF-8 encoded
            0-d bytes array), or ``None`` when the image could not be loaded.
        """
        image_path: Path = self.images[idx]
        try:
            # The context manager closes the underlying file once the pixel
            # data has been copied into the array.
            with Image.open(image_path) as raw:
                image = preprocess_image(raw, self.size_px, self.upscale)
                rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
            # Reverse the channel axis to get BGR; Pillow's "BGR;24" mode is
            # deprecated, so the swap is done in numpy instead.
            image = np.ascontiguousarray(rgb[:, :, ::-1])
        except Exception as e:
            # Best effort: a broken file is logged and reported as None so the
            # collate function can drop it.
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None
        return {"image": image, "path": np.array(str(image_path).encode("utf-8"), dtype=np.bytes_)}
def collate_fn_remove_corrupted(batch):
    """Collate samples into arrays while discarding corrupted ones.

    The dataset is expected to return ``None`` for a sample it could not
    load. Those entries are removed; the remaining dicts are merged key by
    key into numpy arrays. ``None`` is returned when nothing is left.
    """
    valid = [sample for sample in batch if sample is not None]
    if len(valid) == 0:
        return None

    collated = {}
    # The keys of the first valid sample define the keys of the batch.
    for key in valid[0]:
        collated[key] = np.array([sample[key] for sample in valid])
    return collated
| ## Main function | |
class ImageLabeler:
    """Tag images with an ONNX tagger model fetched from the Hugging Face Hub."""

    def __init__(
        self,
        repo_id: Optional[PathLike] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        banned_tags: Optional[list[str]] = None,
    ):
        """Load the model and its label table.

        Args:
            repo_id: Repository id passed to ``create_session`` and
                ``load_labels_hf``.
            general_threshold: Score a general tag must exceed to be kept.
            character_threshold: Score a character tag must exceed to be kept.
            banned_tags: Tag names (as written in the label table) that must
                never be emitted. ``None`` means no banned tags.
        """
        self.repo_id = repo_id
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        self.banned_tags = banned_tags if banned_tags is not None else []

        logging.info(f"Loading model from path: {self.repo_id}")
        self.model = create_session(self.repo_id)
        # The first model input's shape is unpacked as (batch, height, width, channels).
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")

        self.labels = load_labels_hf(self.repo_id)
        # The category lists hold row indices, so banned tags are matched
        # against the tag NAME at each index, not the index itself.
        banned = set(self.banned_tags)
        names = self.labels.names
        self.labels.general = [i for i in self.labels.general if names[i] not in banned]
        self.labels.character = [i for i in self.labels.character if names[i] not in banned]
        logging.info(f"Loaded labels from {self.repo_id}")

    def input_size(self) -> Tuple[int, int]:
        """Return the model input size as ``(height, width)``."""
        return (self.height, self.width)

    def input_name(self) -> Optional[str]:
        """Return the name of the model's first input, or ``None`` without a model."""
        return self.model.get_inputs()[0].name if self.model is not None else None

    def output_name(self) -> Optional[str]:
        """Return the name of the model's first output, or ``None`` without a model."""
        return self.model.get_outputs()[0].name if self.model is not None else None

    @staticmethod
    def _select_tags(labels: list, indices: list, threshold: float) -> dict:
        """Return ``{name: score}`` for ``indices`` scoring above ``threshold``, best first."""
        picked = {labels[i][0]: labels[i][1] for i in indices if labels[i][1] > threshold}
        return dict(sorted(picked.items(), key=lambda item: item[1], reverse=True))

    def label_images(self, images: np.ndarray) -> list[ImageLabels]:
        """Run the model on a batch of preprocessed images.

        Args:
            images: Batch array fed to the model's first input.

        Returns:
            One ``ImageLabels`` per row of the model output.
        """
        # input_name/output_name are methods and must be called; passing the
        # bound methods themselves would hand ONNX Runtime non-string names.
        probs: np.ndarray = self.model.run([self.output_name()], {self.input_name(): images})[0]

        results = []
        for sample in probs:
            labels = list(zip(self.labels.names, sample.astype(float)))

            # Exactly one rating is reported: the highest-scoring one.
            rating_labels = dict(labels[i] for i in self.labels.rating)
            rating = max(rating_labels, key=rating_labels.get)

            gen_labels = self._select_tags(labels, self.labels.general, self.general_threshold)
            char_labels = self._select_tags(labels, self.labels.character, self.character_threshold)

            # General tags first, then character tags; each group is already
            # sorted by descending score (the groups are not interleaved).
            caption = ", ".join([*gen_labels, *char_labels])
            booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character=char_labels,
                    ratings=rating_labels,
                )
            )
        return results

    def __call__(self, images: list[np.ndarray]) -> Generator[list[ImageLabels], None, None]:
        """Yield the ``label_images`` result for each item of ``images``.

        NOTE(review): each item is forwarded unchanged to ``label_images``, so
        it must already be a preprocessed batch array; raw PIL images are not
        converted here.
        """
        for batch in images:
            yield self.label_images(batch)
def main(args):
    """Tag every image in a directory and write one caption file per image.

    Args:
        args: Parsed command-line namespace. Attributes read: ``images_dir``,
            ``variant``, ``recursive``, ``banned_tags``, ``caption_extension``,
            ``print_freqs``, ``num_workers``, ``batch_size``,
            ``remove_underscore``, ``general_threshold``,
            ``character_threshold``, ``thresh`` and ``debug``.

    Raises:
        FileNotFoundError: If ``images_dir`` is not a directory.
        ValueError: If ``variant`` is not a key of ``MODEL_VARIANTS``.
    """
    images_dir: Path = Path(args.images_dir).resolve()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {images_dir}")

    variant: str = args.variant
    recursive: bool = args.recursive or False
    # Strip whitespace and drop empty entries, so "" or "a, b" parse cleanly.
    banned_tags: set[str] = {t.strip() for t in args.banned_tags.split(",") if t.strip()}
    caption_extension: str = str(args.caption_extension).lower()
    # Path.with_suffix() rejects a non-empty suffix without a leading dot.
    if caption_extension and not caption_extension.startswith("."):
        caption_extension = f".{caption_extension}"
    print_freqs: bool = args.print_freqs or False
    num_workers: int = args.num_workers
    batch_size: int = args.batch_size
    remove_underscore: bool = args.remove_underscore or False
    # Compare against None so an explicit threshold of 0.0 is honoured.
    general_threshold: float = (
        args.thresh if args.general_threshold is None else args.general_threshold
    )
    character_threshold: float = (
        args.thresh if args.character_threshold is None else args.character_threshold
    )
    debug: bool = args.debug or False

    # turn the variant name into a repo id
    repo_id: str = MODEL_VARIANTS.get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown base model '{variant}'")

    # collect the image files
    print(f"Loading images from {images_dir}...", end=" ")
    candidates = images_dir.rglob("*") if recursive else images_dir.glob("*")
    image_paths = [p for p in candidates if p.suffix.lower() in IMAGE_EXTENSIONS]

    n_images = len(image_paths)
    print(f"found {n_images} images to process, creating DataLoader...")
    # sort by filename if we have a small number of images
    if n_images < 10000:
        image_paths = sorted(image_paths, key=lambda x: x.stem)

    dataset = ImageDataset(image_paths)

    # prefetch_factor is only accepted when worker processes are used;
    # passing it with num_workers=0 makes DataLoader raise.
    loader_kwargs = {"prefetch_factor": 3} if num_workers > 0 else {}
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
        **loader_kwargs,
    )

    labeler: ImageLabeler = ImageLabeler(
        repo_id=repo_id,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        banned_tags=banned_tags,
    )

    # tag name -> number of captions it appeared in
    tag_freqs = {}

    for batch in tqdm(dataloader, ncols=100, unit="image", unit_scale=batch_size):
        # The collate function returns None when every image in the batch
        # failed to load; there is nothing to label in that case.
        if batch is None:
            continue
        images = batch["image"]
        paths = batch["path"]

        batch_labels = labeler.label_images(images)

        for image_labels, image_path in zip(batch_labels, paths):
            if isinstance(image_path, (np.bytes_, bytes)):
                image_path = Path(image_path.decode("utf-8"))

            # Drop banned tags (and the empty entry produced by an empty
            # caption) before anything is written or counted.
            tags = [t for t in image_labels.caption.split(", ") if t and t not in banned_tags]
            if remove_underscore:
                tags = [t.replace("_", " ") for t in tags]
            caption = ", ".join(tags)

            Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")

            if print_freqs:
                for tag in tags:
                    tag_freqs[tag] = tag_freqs.get(tag, 0) + 1

            if debug:
                print(
                    f"{image_path}:"
                    + f"\n Character tags: {image_labels.character}"
                    + f"\n General tags: {image_labels.general}"
                )

    if print_freqs:
        sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)
        print("\nTag frequencies:")
        for tag, freq in sorted_tags:
            print(f"{tag}: {freq}")

    print("done!")
# Script entry point: parse the command line and run the tagger.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "images_dir",
        type=str,
        help="directory to tag image files in",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default="swinv2",
        # Reject unknown variants at parse time instead of after image discovery.
        choices=list(MODEL_VARIANTS),
        help="name of base model to use (one of 'swinv2', 'convnext', 'vit')",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of threads to use in Torch DataLoader (4 should be plenty)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size for Torch DataLoader (use 1 for cpu, 4-32 for gpu)",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".txt",
        help="extension of caption files to write (e.g. '.txt', '.caption')",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="confidence threshold for adding tags",
    )
    parser.add_argument(
        "--general_threshold",
        type=float,
        default=None,
        help="confidence threshold for general tags - defaults to --thresh",
    )
    parser.add_argument(
        "--character_threshold",
        type=float,
        default=None,
        help="confidence threshold for character tags - defaults to --thresh",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="whether to recurse into subdirectories of images_dir",
    )
    parser.add_argument(
        "--remove_underscore",
        action="store_true",
        help="whether to remove underscores from tags (e.g. 'long_hair' -> 'long hair')",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug logging mode",
    )
    parser.add_argument(
        "--banned_tags",
        type=str,
        default="",
        help="tags to filter out (comma-separated)",
    )
    parser.add_argument(
        "--print_freqs",
        action="store_true",
        help="Print overall tag frequencies at the end",
    )
    # images_dir is a required positional, so argparse never leaves it as
    # None; the former None fallback could not run and has been removed.
    args = parser.parse_args()
    main(args)