"""Tag and caption every image under a folder.

For each ``.jpg``/``.jpeg``/``.png`` image that does not already have a sibling
``.txt`` file, this script writes:

* ``<stem>.tag`` -- WD tagger v3 general/character tags plus a ``rating_*`` tag;
* ``<stem>.txt`` -- a Moondream caption, the same tag string, and the file stem
  with digit-leading words removed.
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image, UnidentifiedImageError
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


@dataclass
class ScriptOptions:
    """Command-line options, parsed with ``simple_parsing``."""

    # Root folder that is searched recursively for images.
    image_folder: Path = "/workspace/ds/reddit"
    # Key into MODEL_REPO_MAP selecting the WD tagger backbone.
    model: str = field(default="vit")
    # General tags are kept only when their probability is strictly above this.
    gen_threshold: float = field(default=0.7)
    # Character tags are kept only when their probability is strictly above this.
    char_threshold: float = field(default=0.6)


# NOTE(review): the Moondream captioner is loaded and compiled at import time
# and is pinned to "cuda" independently of `torch_device` below, so importing
# this module presumably fails on a machine without CUDA.
dream_model = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream-2b-2025-04-14-4bit",
    trust_remote_code=True,
    device_map={"": "cuda"},
)
dream_model.model.compile()

# Device used for the WD tagger model and its input tensors.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): changes the process working directory at import time; any
# relative path used afterwards resolves against this directory.
new_path = '/workspace/wdv3-timm'
os.chdir(new_path)
print(os.getcwd())

# CLI model name -> Hugging Face Hub repo holding the timm weights and tag CSV.
MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}

# Lead-in phrases removed (wherever they occur, in this order) from Moondream captions.
_CAPTION_LEADS = (
    "The image depicts ",
    "The image presents ",
    "The image features ",
    "The image portrays ",
    "The image is ",
)


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any alpha onto a white canvas.

    Images that are neither RGB nor RGBA are converted first: to RGBA when they
    carry a ``transparency`` entry in ``image.info``, otherwise straight to RGB.
    An image that is already RGB is returned unchanged (same object).
    """
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Center ``image`` on a white RGB square whose side is ``max(width, height)``."""
    w, h = image.size
    px = max(image.size)
    canvas = Image.new("RGB", (px, px), (255, 255, 255))
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


@dataclass
class LabelData:
    """Tag names and, per category, the row indices of those tags in ``names``."""

    # Tag names in CSV row order; index i matches output i of the tagger.
    names: list[str]
    # Row indices whose CSV category is 9 (rating tags).
    rating: list[np.int64]
    # Row indices whose CSV category is 0 (general tags).
    general: list[np.int64]
    # Row indices whose CSV category is 4 (character tags).
    character: list[np.int64]


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from ``repo_id`` and index it by category.

    Raises:
        FileNotFoundError: if the download fails with an ``HfHubHTTPError``.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )
    return tag_data


def _sorted_above(scored, indices, threshold):
    """Return ``{name: prob}`` for ``indices`` with ``prob > threshold``, highest prob first."""
    kept = dict(x for x in (scored[i] for i in indices) if x[1] > threshold)
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into caption strings and per-category dicts.

    ``probs`` must be convertible with ``.numpy()`` (i.e. on the CPU) and aligned
    with ``labels.names``; presumably a 1-D tensor with one score per tag.

    Returns:
        ``(caption, taglist, rating_labels, char_labels, gen_labels)`` where
        ``caption`` is the kept general then character tag names joined by
        ``", "`` with underscores turned into spaces, followed by
        ``", rating_<top rating name>"``; ``taglist`` is the same name list with
        underscores turned into spaces and parentheses backslash-escaped (no
        rating suffix); and the three dicts map name -> probability sorted by
        descending probability (ratings are not thresholded).
    """
    scored = list(zip(labels.names, probs.numpy()))

    rating_labels = dict(scored[i] for i in labels.rating)
    rating_labels = dict(sorted(rating_labels.items(), key=lambda item: item[1], reverse=True))

    gen_labels = _sorted_above(scored, labels.general, gen_threshold)
    char_labels = _sorted_above(scored, labels.character, char_threshold)

    combined_names = [*gen_labels, *char_labels]
    caption = ", ".join(combined_names)
    # Raw strings: "\(" is an invalid escape sequence (SyntaxWarning on 3.12+).
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
    caption = caption.replace("_", " ")
    # Empty rating dict yields a bare ", rating_" suffix, as before.
    caption += ", rating_" + next(iter(sorted(rating_labels, key=rating_labels.get, reverse=True)), '')

    return caption, taglist, rating_labels, char_labels, gen_labels


def get_all_images(folder, suffixes=('.jpeg', '.jpg', '.png')):
    """Yield every file under ``folder`` (recursively) whose suffix is in ``suffixes``.

    Matching is case-insensitive on the suffix; directories are skipped and
    paths are yielded in sorted order so runs are reproducible.
    """
    wanted = {s.lower() for s in suffixes}
    for path in sorted(Path(folder).rglob('*')):
        if path.suffix.lower() in wanted and path.is_file():
            yield path


def _process_image(image_path, model, transform, labels, opts):
    """Tag and caption one image, writing its ``.tag`` and ``.txt`` siblings."""
    # Close the source file promptly; the padded canvas is an independent copy.
    with Image.open(image_path) as raw:
        img_input = pil_pad_square(pil_ensure_rgb(raw))

    inputs: Tensor = transform(img_input).unsqueeze(0)
    # Reorder channels RGB -> BGR (img_input is RGB at this point).
    inputs = inputs[:, [2, 1, 0]].to(torch_device)

    with torch.inference_mode():
        mdream_capt = dream_model.caption(img_input, length="normal")["caption"]
        for lead in _CAPTION_LEADS:
            mdream_capt = mdream_capt.replace(lead, "")
        mdream_capt = mdream_capt.strip()

        outputs = torch.sigmoid(model(inputs)).to("cpu")

        caption, _taglist, _ratings, _character, _general = get_tags(
            probs=outputs.squeeze(0),
            labels=labels,
            gen_threshold=opts.gen_threshold,
            char_threshold=opts.char_threshold,
        )

    # Drop words of the file stem that start with an ASCII digit.
    clean_name = ' '.join(
        word for word in image_path.stem.split() if not word.startswith(tuple("0123456789"))
    )

    # Explicit UTF-8: captions and stems may be non-ASCII, and the locale
    # default encoding can raise UnicodeEncodeError (not caught by main).
    with open(image_path.with_suffix('.tag'), 'w', encoding='utf-8') as file_tag:
        file_tag.write(caption)
    with open(image_path.with_suffix('.txt'), 'w', encoding='utf-8') as file_txt:
        file_txt.write(f"{mdream_capt} {caption}. {clean_name}")


def main(opts: ScriptOptions):
    """Load the tagger selected by ``opts`` and process every image lacking a ``.txt``.

    Raises:
        ValueError: if ``opts.model`` is not a key of ``MODEL_REPO_MAP``.
        NotADirectoryError: if ``opts.image_folder`` is not a directory.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)
    if repo_id is None:
        raise ValueError(
            f"Unknown model '{opts.model}'; expected one of: {', '.join(MODEL_REPO_MAP)}"
        )
    image_folder = Path(opts.image_folder).resolve()
    if not image_folder.is_dir():
        raise NotADirectoryError(f"Image folder not found: {image_folder}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)
    # Move the tagger once instead of shuttling it CPU<->GPU for every image.
    model = model.to(torch_device)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    image_paths = list(get_all_images(image_folder))

    for image_path in tqdm(image_paths, desc="Processing images"):
        # NOTE(review): images sharing a stem in one folder (e.g. a.jpg and
        # a.png) share a.txt, so only the first one processed gets captioned.
        if image_path.with_suffix('.txt').exists():
            continue
        try:
            _process_image(image_path, model, transform, labels, opts)
        except (OSError, UnidentifiedImageError) as e:
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"Error processing {image_path}: {str(e)}")
            continue

    print("Done!")


if __name__ == "__main__":
    opts, _ = parse_known_args(ScriptOptions)
    main(opts)