Spaces:
Running
Running
| import json | |
| import math | |
| from dataclasses import dataclass, field | |
| from os import PathLike, cpu_count | |
| from pathlib import Path | |
| from typing import Any, Optional, TypeAlias | |
| import colorcet as cc | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import timm | |
| import torch | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from PIL import Image | |
| from timm.data import create_transform, resolve_data_config | |
| from timm.models import VisionTransformer | |
| from torch import Tensor, nn | |
| from torch.nn import functional as F | |
| from torchvision import transforms as T | |
| from .common import Heatmap, ImageLabels, LabelData, load_labels_hf, pil_ensure_rgb, pil_make_grid | |
# Working directory: the folder containing this file, or the current
# directory when ``__file__`` is not defined (e.g. an interactive session).
if "__file__" in locals():
    work_dir = Path(__file__).parent.resolve()
else:
    work_dir = Path.cwd().resolve()

# Scratch directory under the working directory; created at import time.
temp_dir = work_dir / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)

# Caches keyed by Hugging Face repo id, so a model and its preprocessing
# pipeline are only constructed once per process.
model_cache: dict[str, VisionTransformer] = {}
transform_cache: dict[str, T.Compose] = {}

# Default device: CUDA when torch reports it as available, otherwise CPU.
torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class RGBtoBGR(nn.Module):
    """Transform module that swaps the first and third channels of an image tensor."""

    def forward(self, x: Tensor) -> Tensor:
        """Return *x* with its channels reordered as ``[2, 1, 0]``.

        For a 4-D input the reorder is applied to dim 1; for any other input
        it is applied to dim 0 (which requires at least three dimensions).
        """
        channel_order = [2, 1, 0]
        if x.ndim != 4:
            return x[channel_order, :, :]
        return x[:, channel_order, :, :]
def model_device(model: nn.Module) -> torch.device:
    """Return the device of the first parameter yielded by ``model.parameters()``.

    Raises:
        StopIteration: if the module has no parameters.
    """
    first_param = next(model.parameters())
    return first_param.device
def load_model(repo_id: str) -> VisionTransformer:
    """Return the model for *repo_id*, creating and caching it on first use.

    A newly created model comes from ``timm.create_model`` with the
    ``hf-hub:`` prefix and ``pretrained=True``; it is switched to eval mode
    and moved to ``torch_device`` before being stored in ``model_cache``.
    An entry that is already cached is returned as-is.
    """
    global model_cache

    cached = model_cache.get(repo_id, None)
    if cached is not None:
        return cached

    # cache miss: build the model once and remember it
    model = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
    model_cache[repo_id] = model
    return model
def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """Return the cached model and preprocessing pipeline for *repo_id*.

    Both objects are built on first use and stored in ``model_cache`` /
    ``transform_cache``. A model created here is put in eval mode but is not
    moved to any device. The transform is timm's pipeline for the model's
    pretrained config with an ``RGBtoBGR`` step appended at the end.
    """
    global transform_cache
    global model_cache

    model = model_cache.get(repo_id, None)
    if model is None:
        # cache miss: build the model once and remember it
        model = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
        model_cache[repo_id] = model

    transform = transform_cache.get(repo_id, None)
    if transform is None:
        data_config = resolve_data_config(model.pretrained_cfg, model=model)
        base_transform = create_transform(**data_config)
        # hack in the RGBtoBGR transform as the final step, then cache it
        transform = T.Compose(base_transform.transforms + [RGBtoBGR()])
        transform_cache[repo_id] = transform

    return model, transform
def _labels_above(scored: list, indices: Any, threshold: float) -> dict[str, Any]:
    """Return ``{name: score}`` for *indices* whose score exceeds *threshold*, highest first."""
    kept = {name: score for name, score in (scored[i] for i in indices) if score > threshold}
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))


def get_tags(
    probs: Tensor,
    labels: "LabelData",
    gen_threshold: float,
    char_threshold: float,
) -> tuple[str, str, dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Turn per-label probabilities into caption strings and label dicts.

    Args:
        probs: 1-D tensor of probabilities aligned with ``labels.names``; it is
            converted with ``.numpy()``, so it must live on the CPU and not
            require grad.
        labels: label metadata providing ``names`` plus the index collections
            ``rating``, ``general`` and ``character``.
        gen_threshold: general labels are kept when their score is strictly
            greater than this value.
        char_threshold: same, for character labels.

    Returns:
        ``(caption, booru, rating_labels, char_labels, gen_labels)`` where

        * ``caption`` is a comma-separated tag string with parentheses
          backslash-escaped (general tags first, then character tags, each
          group ordered by descending score),
        * ``booru`` is ``caption`` with underscores replaced by spaces,
        * the three dicts map label name to its score (NumPy scalars as
          produced by ``probs.numpy()``); ``rating_labels`` is unfiltered.
    """
    # Pair every label name with its score.
    scored = list(zip(labels.names, probs.numpy()))

    # Rating labels are always reported, whatever their score.
    rating_labels = dict(scored[i] for i in labels.rating)

    # General / character labels: keep confident ones, most confident first.
    gen_labels = _labels_above(scored, labels.general, gen_threshold)
    char_labels = _labels_above(scored, labels.character, char_threshold)

    # General tags first, then character tags (no re-sort across the groups).
    combined_names = [*gen_labels, *char_labels]

    # Escape parentheses so the caption is usable as a training caption.
    # "\\(" is a literal backslash + "(", spelled without an invalid escape.
    caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    pos_embed_dim: int = 784,
    image_size: tuple[int, int] = (448, 448),
    font_args: Optional[dict] = None,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
    """Render one annotated heatmap overlay per label and a grid of all of them.

    Args:
        image: the preprocessed input image tensor.
            NOTE(review): the ``add(1).mul(127.5)`` below assumes values in
            [-1, 1] and the OpenCV calls assume BGR channel order — confirm
            against the transform that produced it.
        gradients: per-label gradients with respect to ``image_feats``.
            NOTE(review): assumed (labels, 1, tokens, channels) as built by
            ``process_heatmap`` — confirm for other callers.
        image_feats: feature tensor the gradients were taken against.
        image_probs: one score per entry of ``image_labels``.
        image_labels: label text drawn on each overlay; a ``None`` entry
            skips the text for that overlay.
        cmap: colormap applied to the normalised heatmap.
        pos_embed_dim: number of spatial tokens; its square root is the side
            length of the heatmap grid.
        image_size: size the heatmaps are interpolated to.
        font_args: keyword arguments for ``cv2.putText``; ``None`` selects
            white, anti-aliased ``FONT_HERSHEY_SIMPLEX`` at scale 1,
            thickness 2.
        partial_rows: forwarded to ``pil_make_grid``.

    Returns:
        ``(heatmaps, grid)``: the ``Heatmap`` records sorted by descending
        score, and the grid image built from them in that order.
    """
    if font_args is None:
        # Built per call so no mutable default is shared between calls.
        font_args = {
            "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
            "fontScale": 1,
            "color": (255, 255, 255),
            "thickness": 2,
            "lineType": cv2.LINE_AA,
        }

    hmap_dim = int(math.sqrt(pos_embed_dim))

    # Weight the features by the token-averaged gradient of each label, then
    # average over the last (channel) dim to get one value per token.
    image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
    image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)

    # Keep only positive evidence.
    image_hmaps = image_hmaps.clamp(min=0)

    # Smallest positive normal value: guards the divisions below against 0/0
    # (which would yield NaN maps) without affecting non-degenerate maps.
    eps = torch.finfo(image_hmaps.dtype).tiny

    # Scale every map by its own peak.
    peak = image_hmaps.amax(dim=(1, 2), keepdim=True).clamp(min=eps)
    image_hmaps = image_hmaps / peak

    # normalize to 0-1 (per map); a constant map becomes all zeros
    lowest = image_hmaps.amin(dim=(1, 2), keepdim=True)
    highest = image_hmaps.amax(dim=(1, 2), keepdim=True)
    image_hmaps = ((image_hmaps - lowest) / (highest - lowest).clamp(min=eps)).unsqueeze(1)

    # interpolate to input image size
    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)

    # The input frame is the same for every label, so convert it only once.
    image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    hmap_imgs: list[Heatmap] = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        # Colormap output is RGB(A); drop alpha and convert to BGR for blending.
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
        if tag is not None:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)

        # Back to RGB for PIL.
        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))

    hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)

    return hmap_imgs, hmap_grid
def process_heatmap(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelData,
    threshold: float = 0.5,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image, ImageLabels]:
    """Run the model on *image* and build gradient heatmaps for confident labels.

    Args:
        model: model exposing ``forward_features`` and ``forward_head``.
        image: preprocessed input tensor; it is moved to the model's device.
            NOTE(review): ``squeeze(0)`` and the ``[:, 0, :, :]`` slice below
            assume a batch of exactly one image — confirm against callers.
        labels: label metadata; ``labels.names`` is indexed by output position.
        threshold: labels with a sigmoid score strictly above this get a
            heatmap; the same value is used as both the general and the
            character threshold for the returned tags.
        partial_rows: forwarded to ``render_heatmap``.

    Returns:
        ``(heatmaps, grid, image_labels)``: the ``Heatmap`` list and grid
        image from ``render_heatmap``, and an ``ImageLabels`` holding the
        caption strings and the rating / general / character label dicts.
    """
    # Run on whatever device the model lives on (not the module-level default).
    device = model_device(model)

    # Gradients are needed here even if the caller disabled autograd.
    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        probs = F.sigmoid(probs).squeeze(0)

        # Select the labels that get a heatmap.
        probs_mask = probs > threshold
        heatmap_probs = probs[probs_mask]
        label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
        image_labels = [labels.names[index] for index in label_indices.tolist()]

        # Each row of the identity matrix picks out one selected probability,
        # so the batched call yields that label's gradient w.r.t. the features.
        # NOTE(review): with no label above the threshold this receives an
        # empty batch — behaviour in that case has not been verified.
        eye = torch.eye(heatmap_probs.shape[0], device=device)
        grads = torch.autograd.grad(
            outputs=heatmap_probs,
            inputs=features,
            grad_outputs=eye,
            is_grads_batched=True,
            retain_graph=True,
        )
        # Keep the gradient for the first batch item, re-adding a singleton dim.
        grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)

    # Rendering and tag extraction need no autograd tracking.
    with torch.set_grad_enabled(False):
        hmap_imgs, hmap_grid = render_heatmap(
            image=image,
            gradients=grads,
            image_feats=features,
            image_probs=heatmap_probs,
            image_labels=image_labels,
            partial_rows=partial_rows,
        )

        caption, booru, ratings, character, general = get_tags(
            probs=probs.cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
        )
        # Bound to a new name so the ``labels`` parameter keeps its meaning.
        result_labels = ImageLabels(caption, booru, ratings, general, character)

    return hmap_imgs, hmap_grid, result_labels