Plana-Archive commited on
Commit
f19179c
·
verified ·
1 Parent(s): 658d9e6

Upload wd-tagger-heatmap-more-models/tagger/model.py with huggingface_hub

Browse files
wd-tagger-heatmap-more-models/tagger/model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from pathlib import Path
3
+
4
+ import colorcet as cc
5
+ import cv2
6
+ import numpy as np
7
+ import timm
8
+ import torch
9
+ from PIL import Image
10
+ from matplotlib.colors import LinearSegmentedColormap
11
+ from timm.data import create_transform, resolve_data_config
12
+ from timm.models import VisionTransformer
13
+ from torch import Tensor, nn
14
+ from torch.nn import functional as F
15
+ from torchvision import transforms as T
16
+
17
+ from .common import Heatmap, ImageLabels, LabelData, pil_make_grid
18
+
19
+ # working dir, either file parent dir or cwd if interactive
20
+ work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
21
+ temp_dir = work_dir.joinpath("temp")
22
+ temp_dir.mkdir(exist_ok=True, parents=True)
23
+
24
+ # model cache
25
+ model_cache: dict[str, VisionTransformer] = {}
26
+ transform_cache: dict[str, T.Compose] = {}
27
+
28
+ # device to use
29
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+
32
+ class RGBtoBGR(nn.Module):
33
+ def forward(self, x: Tensor) -> Tensor:
34
+ if x.ndim == 4:
35
+ return x[:, [2, 1, 0], :, :]
36
+ return x[[2, 1, 0], :, :]
37
+
38
+
39
+ def model_device(model: nn.Module) -> torch.device:
40
+ return next(model.parameters()).device
41
+
42
+
43
+ def load_model(repo_id: str) -> VisionTransformer:
44
+ global model_cache
45
+
46
+ if model_cache.get(repo_id, None) is None:
47
+ # save model to cache
48
+ model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
49
+
50
+ return model_cache[repo_id]
51
+
52
+
53
+ def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
54
+ global transform_cache
55
+ global model_cache
56
+
57
+ if model_cache.get(repo_id, None) is None:
58
+ # save model to cache
59
+ model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
60
+ model = model_cache[repo_id]
61
+
62
+ if transform_cache.get(repo_id, None) is None:
63
+ transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
64
+ # hack in the RGBtoBGR transform, save to cache
65
+ transform_cache[repo_id] = T.Compose(transforms.transforms + [RGBtoBGR()])
66
+ transform = transform_cache[repo_id]
67
+
68
+ return model, transform
69
+
70
+
71
+ def get_tags(
72
+ probs: Tensor,
73
+ labels: LabelData,
74
+ gen_threshold: float,
75
+ char_threshold: float,
76
+ ):
77
+ # Convert indices+probs to labels
78
+ probs = list(zip(labels.names, probs.numpy()))
79
+
80
+ # First 4 labels are actually ratings
81
+ rating_labels = dict([probs[i] for i in labels.rating])
82
+
83
+ # General labels, pick any where prediction confidence > threshold
84
+ gen_labels = [probs[i] for i in labels.general]
85
+ gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
86
+ gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
87
+
88
+ # Character labels, pick any where prediction confidence > threshold
89
+ char_labels = [probs[i] for i in labels.character]
90
+ char_labels = dict([x for x in char_labels if x[1] > char_threshold])
91
+ char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
92
+
93
+ # Combine general and character labels, sort by confidence
94
+ combined_names = [x for x in gen_labels]
95
+ combined_names.extend([x for x in char_labels])
96
+
97
+ # Convert to a string suitable for use as a training caption
98
+ caption = ", ".join(combined_names).replace("(", "\(").replace(")", "\)")
99
+ booru = caption.replace("_", " ")
100
+
101
+ return caption, booru, rating_labels, char_labels, gen_labels
102
+
103
+
104
+ @torch.no_grad()
105
+ def render_heatmap(
106
+ image: Tensor,
107
+ gradients: Tensor,
108
+ image_feats: Tensor,
109
+ image_probs: Tensor,
110
+ image_labels: list[str],
111
+ cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
112
+ pos_embed_dim: int = 784,
113
+ image_size: tuple[int, int] = (448, 448),
114
+ font_args: dict = {
115
+ "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
116
+ "fontScale": 1,
117
+ "color": (255, 255, 255),
118
+ "thickness": 2,
119
+ "lineType": cv2.LINE_AA,
120
+ },
121
+ partial_rows: bool = True,
122
+ ) -> tuple[list[Heatmap], Image.Image]:
123
+ # hmap_dim = int(math.sqrt(pos_embed_dim))
124
+
125
+ image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
126
+ hmap_dim = int(math.sqrt(image_hmaps.mean(-1).numel() / len(image_labels)))
127
+ image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), -1)
128
+ image_hmaps = image_hmaps[..., -hmap_dim ** 2:]
129
+ image_hmaps = image_hmaps.reshape(len(image_labels), hmap_dim, hmap_dim)
130
+ image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
131
+
132
+ image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
133
+ # normalize to 0-1
134
+ image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min()) for x in image_hmaps]).unsqueeze(1)
135
+ # interpolate to input image size
136
+ image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)
137
+
138
+ hmap_imgs: list[Heatmap] = []
139
+ for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
140
+ image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
141
+ hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
142
+
143
+ hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
144
+ hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
145
+ if tag is not None:
146
+ cv2.putText(hmap_image, tag, (10, 30), **font_args)
147
+ cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)
148
+
149
+ hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
150
+ hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))
151
+
152
+ hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
153
+ hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
154
+
155
+ return hmap_imgs, hmap_grid
156
+
157
+
158
+ def process_heatmap(
159
+ model: VisionTransformer,
160
+ image: Tensor,
161
+ labels: LabelData,
162
+ threshold: float = 0.5,
163
+ partial_rows: bool = True,
164
+ ) -> tuple[list[tuple[float, str, Image.Image]], Image.Image, ImageLabels]:
165
+ torch_device = model_device(model)
166
+
167
+ with torch.set_grad_enabled(True):
168
+ features = model.forward_features(image.to(torch_device))
169
+ probs = model.forward_head(features)
170
+ probs = F.sigmoid(probs).squeeze(0)
171
+
172
+ probs_mask = probs > threshold
173
+ heatmap_probs = probs[probs_mask]
174
+
175
+ label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
176
+ image_labels = [labels.names[label_indices[i]] for i in range(len(label_indices))]
177
+
178
+ eye = torch.eye(heatmap_probs.shape[0], device=torch_device)
179
+ grads = torch.autograd.grad(
180
+ outputs=heatmap_probs,
181
+ inputs=features,
182
+ grad_outputs=eye,
183
+ is_grads_batched=True,
184
+ retain_graph=True,
185
+ )
186
+ grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)
187
+
188
+ with torch.set_grad_enabled(False):
189
+ hmap_imgs, hmap_grid = render_heatmap(
190
+ image=image,
191
+ gradients=grads,
192
+ image_feats=features,
193
+ image_probs=heatmap_probs,
194
+ image_labels=image_labels,
195
+ partial_rows=partial_rows,
196
+ )
197
+
198
+ caption, booru, ratings, character, general = get_tags(
199
+ probs=probs.cpu(),
200
+ labels=labels,
201
+ gen_threshold=threshold,
202
+ char_threshold=threshold,
203
+ )
204
+ labels = ImageLabels(caption, booru, ratings, general, character)
205
+
206
+ return hmap_imgs, hmap_grid, labels