# sdxs / src / caption.py
# recoilme's picture
# 0806
# 50e2534
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
import os
import time
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from PIL import Image, UnidentifiedImageError
from pathlib import Path
from tqdm import tqdm
@dataclass
class ScriptOptions:
    """Command-line options for the captioning script."""

    # Folder that is searched recursively for images to caption.
    # The default is a real Path so the value matches its annotation.
    image_folder: Path = field(default=Path("/workspace/ds/reddit"))
    # Key into MODEL_REPO_MAP selecting which tagger repository to load.
    model: str = field(default="vit")
    # A general tag is kept only when its score is strictly above this value.
    gen_threshold: float = field(default=0.7)
    # A character tag is kept only when its score is strictly above this value.
    char_threshold: float = field(default=0.6)
# Moondream model used in main() to produce the free-text caption of each image.
# It is loaded (and compiled) once, as a side effect of importing/running this module.
# NOTE(review): device_map pins the model to "cuda" even though torch_device below
# can fall back to "cpu" -- confirm that a GPU is always available where this runs.
# NOTE(review): trust_remote_code=True presumably lets transformers execute code
# shipped in the model repository -- consider pinning a revision.
dream_model = AutoModelForCausalLM.from_pretrained(
"moondream/moondream-2b-2025-04-14-4bit",
trust_remote_code=True,
device_map={"": "cuda"}
)
dream_model.model.compile()
# Device on which main() runs the tagger model: CUDA when available, otherwise CPU.
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Change the process working directory to a hard-coded location and echo it.
# NOTE(review): this affects every relative path resolved after this point.
new_path = '/workspace/wdv3-timm'
os.chdir(new_path)
print(os.getcwd())
# Maps the value of the `model` option to the Hugging Face repository of the
# corresponding tagger.
MODEL_REPO_MAP = dict(
    vit="SmilingWolf/wd-vit-tagger-v3",
    swinv2="SmilingWolf/wd-swinv2-tagger-v3",
    convnext="SmilingWolf/wd-convnext-tagger-v3",
)
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, compositing any alpha channel onto white.

    Images that are already RGB are returned unchanged (same object).
    """
    # Anything that is neither RGB nor RGBA is converted first; RGBA is chosen
    # when the image carries a "transparency" entry in its info dict.
    if image.mode not in ["RGB", "RGBA"]:
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Flatten the alpha channel by compositing onto a (255, 255, 255) canvas.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Return a new white RGB square with *image* pasted in its centre.

    The side of the square equals the longer edge of *image*.
    """
    width, height = image.size
    side = max(width, height)
    # Integer offsets that centre the image along both axes.
    top_left = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, top_left)
    return square
@dataclass
class LabelData:
    """Tag vocabulary of a tagger model, with positions grouped by category.

    The three index lists hold positions into ``names``.
    """

    # Tag names, in vocabulary order.
    names: list[str]
    # Positions of the rating tags.
    rating: list[np.int64]
    # Positions of the general tags.
    general: list[np.int64]
    # Positions of the character tags.
    character: list[np.int64]
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from *repo_id* and build a ``LabelData``.

    Args:
        repo_id: Hugging Face repository that hosts the tag list.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The tag names plus the row positions of the rating (category 9),
        general (category 0) and character (category 4) tags.

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    csv_path = Path(downloaded).resolve()

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]

    def rows_in(category_id: int) -> list:
        # Row positions whose "category" column equals category_id.
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag scores into a caption string and per-category score dicts.

    Args:
        probs: One score per entry of ``labels.names``. ``.numpy()`` is called
            on it, so it is expected to be a CPU tensor.
        labels: Tag names and the positions of each tag category.
        gen_threshold: General tags are kept only when score > this value.
        char_threshold: Character tags are kept only when score > this value.

    Returns:
        A tuple ``(caption, taglist, rating_labels, char_labels, gen_labels)``:

        * ``caption``: kept general then character tags joined by ``", "``,
          underscores replaced by spaces, followed by ``", rating_<top rating>"``.
        * ``taglist``: the same tags without the rating suffix, with ``(`` and
          ``)`` backslash-escaped.
        * the three dicts map tag name to score, ordered by descending score;
          ``rating_labels`` is not thresholded.
    """
    # Keep the input tensor and the (name, score) pairs under separate names.
    scored = list(zip(labels.names, probs.numpy()))

    def ranked(indices, threshold=None):
        # One entry per tag name, optionally thresholded, highest score first.
        kept = {
            name: score
            for name, score in (scored[i] for i in indices)
            if threshold is None or score > threshold
        }
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    rating_labels = ranked(labels.rating)
    gen_labels = ranked(labels.general, gen_threshold)
    char_labels = ranked(labels.character, char_threshold)

    caption = ", ".join([*gen_labels, *char_labels])
    # Raw strings: "\(" is an invalid escape sequence in a normal literal.
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
    caption = caption.replace("_", " ")

    # rating_labels is already ordered by descending score, so its first key
    # is the most likely rating; fall back to "" when there are no ratings.
    top_rating = next(iter(rating_labels), "")
    caption += ", rating_" + top_rating

    return caption, taglist, rating_labels, char_labels, gen_labels
def get_all_images(folder: Path):
    """Yield every path beneath *folder* that has a JPEG or PNG suffix.

    The search is recursive and the suffix comparison is case-insensitive.
    Paths are yielded lazily, in the order ``folder.rglob`` produces them.

    Args:
        folder: Directory to search.

    Yields:
        Each path whose suffix is ``.jpeg``, ``.jpg`` or ``.png``.
    """
    image_suffixes = (".jpeg", ".jpg", ".png")
    for path in folder.rglob("*"):
        if path.suffix.lower() in image_suffixes:
            yield path
def main(opts: ScriptOptions):
    """Caption and tag every image found under ``opts.image_folder``.

    For each image that does not yet have a sibling ``.txt`` file, two files
    are written next to it:

    * ``<stem>.tag``: the tagger caption returned by ``get_tags``.
    * ``<stem>.txt``: the moondream caption, the tagger caption and the
      image stem with digit-leading words removed.

    Images that fail with ``OSError`` (including ``UnidentifiedImageError``)
    are reported and skipped.

    Args:
        opts: Parsed script options.

    Raises:
        NotADirectoryError: If ``opts.image_folder`` is not a directory.
        ValueError: If ``opts.model`` is not a key of ``MODEL_REPO_MAP``.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)

    image_folder = Path(opts.image_folder).resolve()
    if not image_folder.is_dir():
        raise NotADirectoryError(f"Image folder not found: {image_folder}")

    # Fail with a clear message instead of a TypeError from "hf-hub:" + None.
    if repo_id is None:
        raise ValueError(
            f"Unknown model '{opts.model}'; expected one of {sorted(MODEL_REPO_MAP)}"
        )

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    # Place the tagger on its device once, not once per image.
    model = model.to(torch_device)

    # Lead-in phrases removed from the moondream caption, in this order.
    boilerplate = (
        "The image depicts ",
        "The image presents ",
        "The image features ",
        "The image portrays ",
        "The image is ",
    )
    digits = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '0')

    image_paths = list(get_all_images(image_folder))
    for image_path in tqdm(image_paths, desc="Processing images"):
        txt_file = image_path.with_suffix('.txt')
        # An existing .txt marks the image as already processed.
        if txt_file.exists():
            continue

        try:
            # The context manager closes the file handle; pil_pad_square
            # returns a fresh canvas, so img_input stays usable afterwards.
            with Image.open(image_path) as raw_image:
                img_input: Image.Image = pil_pad_square(pil_ensure_rgb(raw_image))

            inputs: Tensor = transform(img_input).unsqueeze(0)
            # Reverse the channel axis (presumably RGB -> BGR for the tagger).
            inputs = inputs[:, [2, 1, 0]]

            with torch.inference_mode():
                mdream_capt = dream_model.caption(img_input, length="normal")["caption"]
                for phrase in boilerplate:
                    mdream_capt = mdream_capt.replace(phrase, "")
                mdream_capt = mdream_capt.strip()

                outputs = model.forward(inputs.to(torch_device))
                # Scores are moved to the CPU because get_tags calls .numpy().
                outputs = F.sigmoid(outputs).to("cpu")

            caption, taglist, ratings, character, general = get_tags(
                probs=outputs.squeeze(0),
                labels=labels,
                gen_threshold=opts.gen_threshold,
                char_threshold=opts.char_threshold,
            )

            # Drop whitespace-separated words of the stem that start with a digit.
            clean_name = ' '.join(
                word for word in image_path.stem.split() if not word.startswith(digits)
            )

            # Explicit UTF-8 so the output does not depend on the locale.
            # The .tag file is written first; the .txt file is the done marker.
            image_path.with_suffix('.tag').write_text(f"{caption}", encoding="utf-8")
            txt_file.write_text(f"{mdream_capt} {caption}. {clean_name}", encoding="utf-8")
        except (OSError, UnidentifiedImageError) as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue

    print("Done!")
if __name__ == "__main__":
    # parse_known_args also returns the unrecognised arguments; only the
    # parsed options object is used.
    opts = parse_known_args(ScriptOptions)[0]
    main(opts)