| | |
| | import argparse |
| | import logging |
| | from dataclasses import dataclass |
| | from os import PathLike |
| | from pathlib import Path |
| | from typing import Generator, Optional, Tuple |
| |
|
| | import numpy as np |
| | import onnxruntime as rt |
| | from huggingface_hub import hf_hub_download |
| | from huggingface_hub.utils import HfHubHTTPError |
| | from pandas import DataFrame, read_csv |
| | from PIL import Image |
| | from torch.utils.data import DataLoader, Dataset |
| | from tqdm import tqdm |
| |
|
| | |
# File suffixes (lowercase) accepted as images when scanning directories.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Default square edge length (pixels) used by ImageDataset when resizing.
IMAGE_SIZE = 448

# Short variant name -> Hugging Face Hub repo id for the tagger model.
MODEL_VARIANTS: dict[str, str] = {
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
    "vit": "SmilingWolf/wd-vit-tagger-v3",
}
| |
|
| |
|
@dataclass
class LabelData:
    """Tag metadata parsed from a model repo's selected_tags.csv."""

    names: list[str]  # all tag names, in CSV row order
    rating: list[np.int64]  # row indices into `names` for rating tags (category 9)
    general: list[np.int64]  # row indices for general tags (category 0)
    character: list[np.int64]  # row indices for character tags (category 4)
| |
|
| |
|
@dataclass
class ImageLabels:
    """Tagging results produced for a single image."""

    caption: str  # comma-joined general + character tag names above threshold
    booru: str  # caption with underscores spaced and parentheses escaped
    rating: str  # name of the highest-scoring rating tag
    general: dict[str, float]  # general tags above threshold -> confidence
    character: dict[str, float]  # character tags above threshold -> confidence
    ratings: dict[str, float]  # all rating tags -> confidence
| |
|
| |
|
# Configure the root logger: basicConfig installs the handler/format at
# WARNING, then the root level is lowered to INFO so this script's
# logging.info(...) calls are actually emitted.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)
| |
|
| |
|
| | |
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Fetch an ONNX model file from the HF Hub and return its resolved local path."""
    # Normalize bare names like "model" to "model.onnx".
    if not filename.endswith(".onnx"):
        filename = f"{filename}.onnx"

    local_file = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
    return Path(local_file).resolve()
| |
|
| |
|
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download the model for `repo_id` and open an ONNX Runtime session on it."""
    onnx_path = download_onnx(repo_id, revision=revision, token=token)
    # If the hub returned a directory, look for the model file inside it.
    if not onnx_path.is_file():
        onnx_path = onnx_path.joinpath("model.onnx")
    if not onnx_path.is_file():
        raise FileNotFoundError(f"Model not found: {onnx_path}")

    # Prefer CUDA when available; ORT falls back to CPU otherwise.
    return rt.InferenceSession(
        str(onnx_path),
        providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
    )
| |
|
| |
|
| | |
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download selected_tags.csv from `repo_id` and split tag indices by category."""
    try:
        local_csv = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    csv_path = Path(local_csv).resolve()

    df: DataFrame = read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]
    # Category codes: 9 = rating, 0 = general, 4 = character.
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(categories == 9)[0]),
        general=list(np.where(categories == 0)[0]),
        character=list(np.where(categories == 4)[0]),
    )
| |
|
| |
|
| | |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return `image` as RGB, compositing any transparency onto white."""
    # Promote palette/grayscale/etc. images, keeping alpha only when present.
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    # Flatten an alpha channel against an opaque white background.
    if image.mode == "RGBA":
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.alpha_composite(image)
        image = background.convert("RGB")
    return image
| |
|
| |
|
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center `image` on a square canvas whose edge is its longer side."""
    width, height = image.size
    edge = max(width, height)
    squared = Image.new("RGB", (edge, edge), fill)
    squared.paste(image, ((edge - width) // 2, (edge - height) // 2))
    return squared
| |
|
| |
|
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    image = pil_pad_square(pil_ensure_rgb(image))

    # Smaller than target: enlarge with LANCZOS, unless the caller forbids it.
    if image.size[0] < target[0] or image.size[1] < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(target, Image.LANCZOS)
    # Larger than target: shrink in place with BICUBIC.
    if image.size[0] > target[0] or image.size[1] > target[1]:
        image.thumbnail(target, Image.BICUBIC)

    return image
| |
|
| |
|
| | |
class ImageDataset(Dataset):
    """Torch dataset that loads, squares, and resizes images for the tagger.

    Each item is a dict with a float32 BGR HWC image array and the source
    path as UTF-8 bytes; unreadable images yield None so the collate
    function can drop them.
    """

    def __init__(self, image_paths: list[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        # size_px: target square edge; upscale: allow enlarging small images.
        self.size_px = size_px
        self.upscale = upscale
        # Keep only paths with a recognized image extension.
        self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path: Path = self.images[idx]
        try:
            image = Image.open(image_path)
            image = preprocess_image(image, self.size_px, self.upscale)
            # RGB -> BGR via channel reversal. The previous
            # `.convert("BGR;24")` raw-mode trick is unsupported in modern
            # Pillow and raises ValueError; the numpy flip is equivalent.
            image = np.asarray(image.convert("RGB"), dtype=np.float32)[:, :, ::-1]
        except Exception as e:
            # Best-effort: log the failure and let the collate fn skip it.
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None

        return {"image": image, "path": np.array(str(image_path).encode("utf-8"), dtype=np.bytes_)}
| |
|
| |
|
def collate_fn_remove_corrupted(batch):
    """Collate function that allows to remove corrupted examples in the
    dataloader. It expects that the dataloader returns 'None' when that occurs.
    The 'None's in the batch are removed.
    """
    # Drop samples the dataset flagged as unreadable.
    batch = [x for x in batch if x is not None]
    if not batch:
        return None
    # Stack each field across surviving samples. Keys come from the first
    # sample; the redundant None re-check from the original is removed
    # (the batch is already filtered above).
    return {key: np.array([sample[key] for sample in batch]) for key in batch[0]}
| |
|
| |
|
| | |
class ImageLabeler:
    """Wraps a WD-v3 ONNX tagger session and its label metadata.

    Loads the model and selected_tags.csv from the HF Hub repo `repo_id`,
    then scores batches of preprocessed float32 images against the
    general/character/rating tag sets.
    """

    def __init__(
        self,
        repo_id: Optional[PathLike] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        banned_tags: Optional[list[str]] = None,
    ):
        self.repo_id = repo_id
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        # `None` default (instead of a shared mutable `[]`) means "no bans".
        self.banned_tags = list(banned_tags) if banned_tags is not None else []

        logging.info(f"Loading model from path: {self.repo_id}")
        self.model = create_session(self.repo_id)

        # Model input tensor layout is NHWC.
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")

        # Load tag metadata, then drop banned tags by *name*. The previous
        # code compared numeric row indices against tag-name strings, so
        # bans never actually removed anything.
        self.labels = load_labels_hf(self.repo_id)
        banned = set(self.banned_tags)
        self.labels.general = [i for i in self.labels.general if self.labels.names[i] not in banned]
        self.labels.character = [i for i in self.labels.character if self.labels.names[i] not in banned]
        logging.info(f"Loaded labels from {self.repo_id}")

    @property
    def input_size(self) -> Tuple[int, int]:
        """(height, width) expected by the model."""
        return (self.height, self.width)

    @property
    def input_name(self) -> str:
        """Name of the model's input tensor."""
        return self.model.get_inputs()[0].name if self.model is not None else None

    @property
    def output_name(self) -> str:
        """Name of the model's output tensor."""
        return self.model.get_outputs()[0].name if self.model is not None else None

    def label_images(self, images: np.ndarray) -> list[ImageLabels]:
        """Run the model on a batch of images and build per-image labels.

        Args:
            images: float32 array shaped (batch, H, W, 3), as produced by
                ImageDataset / collate_fn_remove_corrupted.

        Returns:
            One ImageLabels per image in the batch.
        """
        probs: np.ndarray = self.model.run([self.output_name], {self.input_name: images})[0]

        results = []
        for sample in probs:
            labels = list(zip(self.labels.names, sample.astype(float)))

            # Pick the single best rating tag.
            rating_labels = dict([labels[i] for i in self.labels.rating])
            rating = max(rating_labels, key=rating_labels.get)

            # General tags above threshold, highest confidence first.
            gen_labels = [labels[i] for i in self.labels.general]
            gen_labels = dict(x for x in gen_labels if x[1] > self.general_threshold)
            gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

            # Character tags above threshold, highest confidence first.
            char_labels = [labels[i] for i in self.labels.character]
            char_labels = dict(x for x in char_labels if x[1] > self.character_threshold)
            char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

            # Caption is general tags followed by character tags.
            combined_names = list(gen_labels)
            combined_names.extend(char_labels)

            caption = ", ".join(combined_names)
            # Booru form: spaces instead of underscores, parentheses escaped.
            # (Escapes written as "\\(" to avoid the invalid "\(" escape
            # sequence; the runtime strings are unchanged.)
            booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character=char_labels,
                    ratings=rating_labels,
                )
            )

        return results

    def __call__(self, images: list[Image.Image]) -> Generator[ImageLabels, None, None]:
        # NOTE(review): each yield is the *list* returned by label_images, and
        # a PIL image is passed where a batched ndarray is expected — callers
        # should prefer label_images on preprocessed batches. Behavior kept
        # as-is for compatibility.
        for x in images:
            yield self.label_images(x)
| |
|
| |
|
def main(args):
    """Tag every image under args.images_dir and write caption sidecar files."""
    images_dir: Path = Path(args.images_dir).resolve()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {images_dir}")

    variant: str = args.variant
    recursive: bool = args.recursive or False
    # Drop empty entries so the default "" (and stray commas) ban nothing;
    # the original produced {""} from the empty default.
    banned_tags: set[str] = {t.strip() for t in args.banned_tags.split(",") if t.strip()}
    caption_extension: str = str(args.caption_extension).lower()
    print_freqs: bool = args.print_freqs or False
    num_workers: int = args.num_workers
    batch_size: int = args.batch_size

    remove_underscore: bool = args.remove_underscore or False
    general_threshold: float = args.general_threshold or args.thresh
    character_threshold: float = args.character_threshold or args.thresh
    debug: bool = args.debug or False

    repo_id: str = MODEL_VARIANTS.get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown base model '{variant}'")

    print(f"Loading images from {images_dir}...", end=" ")
    # rglob("*") is already recursive; the original's rglob("**/*") was
    # doubly recursive and needlessly slow on deep trees.
    glob_iter = images_dir.rglob("*") if recursive is True else images_dir.glob("*")
    image_paths = [p for p in glob_iter if p.suffix.lower() in IMAGE_EXTENSIONS]

    n_images = len(image_paths)
    print(f"found {n_images} images to process, creating DataLoader...")

    # Sort for deterministic processing order; skipped on huge sets.
    if n_images < 10000:
        image_paths = sorted(image_paths, key=lambda x: x.stem)
    dataset = ImageDataset(image_paths)

    # prefetch_factor is only legal with worker processes; passing it with
    # num_workers=0 makes DataLoader raise a ValueError.
    loader_kwargs = {"prefetch_factor": 3} if num_workers > 0 else {}
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
        **loader_kwargs,
    )

    labeler: ImageLabeler = ImageLabeler(
        repo_id=repo_id,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        banned_tags=banned_tags,
    )

    tag_freqs = {}

    for batch in tqdm(dataloader, ncols=100, unit="image", unit_scale=batch_size):
        # The collate fn returns None when every sample in the batch was
        # corrupt; the original crashed on batch["image"] in that case.
        if batch is None:
            continue
        images = batch["image"]
        paths = batch["path"]

        batch_labels = labeler.label_images(images)

        for image_labels, image_path in zip(batch_labels, paths):
            if isinstance(image_path, (np.bytes_, bytes)):
                image_path = Path(image_path.decode("utf-8"))

            caption = image_labels.caption
            if remove_underscore is True:
                caption = caption.replace("_", " ")
            # Write the caption next to the image, swapping the extension.
            Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")

            if print_freqs is True:
                for tag in caption.split(", "):
                    if tag in banned_tags:
                        continue
                    tag_freqs[tag] = tag_freqs.get(tag, 0) + 1

            if debug is True:
                print(
                    f"{image_path}:"
                    + f"\n Character tags: {image_labels.character}"
                    + f"\n General tags: {image_labels.general}"
                )

    if print_freqs:
        sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)
        print("\nTag frequencies:")
        for tag, freq in sorted_tags:
            print(f"{tag}: {freq}")

    print("done!")
| |
|
| |
|
# Script entry point: build the CLI, parse flags, and run the tagger.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "images_dir",
        type=str,
        help="directory to tag image files in",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default="swinv2",
        help="name of base model to use (one of 'swinv2', 'convnext', 'vit')",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of threads to use in Torch DataLoader (4 should be plenty)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size for Torch DataLoader (use 1 for cpu, 4-32 for gpu)",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".txt",
        help="extension of caption files to write (e.g. '.txt', '.caption')",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="confidence threshold for adding tags",
    )
    parser.add_argument(
        "--general_threshold",
        type=float,
        default=None,
        help="confidence threshold for general tags - defaults to --thresh",
    )
    parser.add_argument(
        "--character_threshold",
        type=float,
        default=None,
        help="confidence threshold for character tags - defaults to --thresh",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="whether to recurse into subdirectories of images_dir",
    )
    parser.add_argument(
        "--remove_underscore",
        action="store_true",
        help="whether to remove underscores from tags (e.g. 'long_hair' -> 'long hair')",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug logging mode",
    )
    parser.add_argument(
        "--banned_tags",
        type=str,
        default="",
        help="tags to filter out (comma-separated)",
    )
    parser.add_argument(
        "--print_freqs",
        action="store_true",
        help="Print overall tag frequencies at the end",
    )

    args = parser.parse_args()
    # NOTE(review): dead code — `images_dir` is a required positional, so
    # argparse exits before this point when it is missing; it is never None.
    if args.images_dir is None:
        args.images_dir = Path.cwd().joinpath("temp/test")

    main(args)
| |
|