IdlecloudX committed on
Commit
12cfca5
·
verified ·
1 Parent(s): f40029f

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +279 -0
handler.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import requests
11
+ import timm
12
+ import torch
13
+ import torchvision.transforms as transforms
14
+ from PIL import Image
15
+
16
+
17
class TaggingHead(torch.nn.Module):
    """Linear classification head that outputs per-class sigmoid probabilities.

    The single ``Linear`` layer is wrapped in ``torch.nn.Sequential`` so its
    parameters are registered as ``head.0.weight`` / ``head.0.bias``.
    """

    def __init__(self, input_dim, num_classes):
        """Create the head mapping ``input_dim`` features to ``num_classes`` outputs."""
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        linear = torch.nn.Linear(input_dim, num_classes)
        self.head = torch.nn.Sequential(linear)

    def forward(self, x):
        """Return ``sigmoid(head(x))``, i.e. independent per-class probabilities."""
        return torch.sigmoid(self.head(x))
28
+
29
+
30
def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """Read the tag metadata JSON file.

    Args:
        tags_file: Path to a UTF-8 JSON file with a ``tag_map`` entry and a
            ``tag_split`` entry holding ``gen_tag_count`` and
            ``character_tag_count``.

    Returns:
        ``(tag_map, gen_tag_count, character_tag_count)``.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    tag_info = json.loads(tags_file.read_text(encoding="utf-8"))
    tag_map = tag_info["tag_map"]
    split = tag_info["tag_split"]
    return tag_map, split["gen_tag_count"], split["character_tag_count"]
38
+
39
+
40
def get_character_ip_mapping(mapping_file: Path):
    """Parse ``mapping_file`` (UTF-8 JSON) and return the decoded object unchanged."""
    return json.loads(mapping_file.read_text(encoding="utf-8"))
44
+
45
+
46
def get_encoder():
    """Build the timm backbone architecture without loading pretrained weights.

    ``reset_classifier(0)`` is called on the created model so that it has zero
    output classes; the weights are expected to be loaded separately.
    """
    backbone = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=False
    )
    backbone.reset_classifier(0)
    return backbone
51
+
52
+
53
def get_decoder():
    """Create the tagging head: 1024 input features mapped to 13461 classes."""
    return TaggingHead(input_dim=1024, num_classes=13461)
56
+
57
+
58
def get_model():
    """Chain the encoder (index 0) and decoder (index 1) into one ``Sequential``."""
    return torch.nn.Sequential(get_encoder(), get_decoder())
63
+
64
+
65
def load_model(weights_file, device):
    """Build the model, load its state dict and return it ready for inference.

    Args:
        weights_file: Path to the saved state dict; read with
            ``weights_only=True`` and mapped onto ``device``.
        device: Target device the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    network = get_model()
    network.load_state_dict(
        torch.load(weights_file, map_location=device, weights_only=True)
    )
    # Module.to() and Module.eval() both return the module itself.
    return network.to(device).eval()
72
+
73
+
74
def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """Flatten an image with an alpha channel onto a solid RGB background.

    The image's 4th band is used as the paste mask, so the caller must pass an
    image that has at least four bands (e.g. mode ``RGBA``).

    Args:
        image: Source image whose 4th band is the alpha channel.
        color: RGB fill colour of the background; white by default.

    Returns:
        A new ``RGB`` image of the same size.
    """
    image.load()  # split() needs the pixel data to be loaded
    alpha_band = image.split()[3]  # band index 3 is the alpha channel
    canvas = Image.new("RGB", image.size, color)
    canvas.paste(image, mask=alpha_band)
    return canvas
87
+
88
+
89
def pil_to_rgb(image: Image.Image) -> Image.Image:
    """Return an ``RGB`` version of ``image``, compositing any alpha onto white.

    Palette images (mode ``P``) and every mode that carries an ``A`` band
    (``RGBA``, ``LA``, ``PA``) are converted to ``RGBA`` when necessary and
    flattened with ``pure_pil_alpha_to_color_v2``. Previously only ``RGBA``
    and ``P`` were flattened, so ``LA``/``PA`` images had their alpha band
    dropped by ``convert("RGB")`` instead of being composited.

    All other modes are converted with a plain ``convert("RGB")``.
    """
    if image.mode == "P" or "A" in image.getbands():
        # Avoid a needless copy when the image is already RGBA.
        rgba = image if image.mode == "RGBA" else image.convert("RGBA")
        return pure_pil_alpha_to_color_v2(rgba)
    return image.convert("RGB")
97
+
98
+
99
class EndpointHandler:
    """Callable request handler that tags one image per call.

    Construction loads the model weights, the tag metadata and the
    character-to-IP mapping from the model directory. Calling the instance
    fetches/decodes the image, runs the model and turns the probabilities
    into tag lists.
    """

    def __init__(self, path: str):
        """Load all artefacts from ``path`` and select the compute device.

        Args:
            path: Directory containing ``model_v0.9.pth``,
                ``tags_v0.9_13k.json`` and ``char_ip_map.json``.

        Raises:
            FileNotFoundError: If the directory or a required file is missing.
        """
        repo_path = Path(path)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not repo_path.is_dir():
            raise FileNotFoundError(f"Model directory not found: {repo_path}")
        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        for label, required in (
            ("Model", weights_file),
            ("Tags", tags_file),
            ("Mapping", mapping_file),
        ):
            if not required.exists():
                raise FileNotFoundError(f"{label} file not found: {required}")

        # Robust device selection: use CUDA only when it is truly usable.
        force_cpu = os.environ.get("FORCE_CPU", "0").strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
        self.device = "cpu"
        if not force_cpu and torch.cuda.is_available():
            try:
                # Probe that CUDA can actually be used (driver present).
                torch.zeros(1).to("cuda")
            except Exception:
                # Deliberate best-effort fallback, but leave a trace in the logs.
                logging.warning(
                    "CUDA is reported as available but unusable; falling back to CPU",
                    exc_info=True,
                )
            else:
                self.device = "cuda"

        self.model = load_model(str(weights_file), self.device)
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.fetch_image_timeout = 5.0
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)

        # Invert the tag_map for efficient index-to-tag lookups.
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}

        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

    def _load_image(self, inputs: Any) -> "Image.Image | None":
        """Resolve request inputs to a PIL image; return ``None`` if none is given."""
        if isinstance(inputs, Image.Image):
            return inputs
        if not isinstance(inputs, dict):
            raise ValueError(f"Unsupported inputs type: {type(inputs).__name__}")
        if image_url := inputs.pop("url", None):
            res = requests.get(image_url, timeout=self.fetch_image_timeout)
            res.raise_for_status()
            # Buffer the decoded body: ``res.raw`` is the undecoded transport
            # stream and would be closed before PIL lazily reads pixel data.
            return Image.open(io.BytesIO(res.content))
        if image_base64_encoded := inputs.pop("image", None):
            return Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
        return None

    def _select_tags(
        self,
        probs,
        mode,
        general_threshold,
        character_threshold,
        topk_general,
        topk_character,
    ):
        """Pick tag indices from ``probs`` and return ``(indices, scores)`` as lists.

        Indices below ``self.gen_tag_count`` are general tags, the rest are
        character tags. ``mode == "topk"`` takes the top-k entries of each
        group; any other mode applies the two thresholds.
        """
        gen_probs = probs[: self.gen_tag_count]
        char_probs = probs[self.gen_tag_count :]

        if mode == "topk":
            # Clamp k to the real slice length as well as the metadata count,
            # so torch.topk can never be asked for more entries than exist.
            k_gen = max(0, min(int(topk_general), self.gen_tag_count, gen_probs.numel()))
            k_char = max(
                0,
                min(int(topk_character), self.character_tag_count, char_probs.numel()),
            )
            # Empty placeholders live on the same device as ``probs`` so the
            # concatenation below also works when the model runs on CUDA.
            gen_scores = probs.new_empty(0)
            gen_idx = probs.new_empty(0, dtype=torch.long)
            char_scores = probs.new_empty(0)
            char_idx = probs.new_empty(0, dtype=torch.long)
            if k_gen > 0:
                gen_scores, gen_idx = torch.topk(gen_probs, k_gen)
            if k_char > 0:
                char_scores, char_idx = torch.topk(char_probs, k_char)
                char_idx = char_idx + self.gen_tag_count
            indices = torch.cat((gen_idx, char_idx))
            scores = torch.cat((gen_scores, char_scores))
        else:
            general_indices = (gen_probs > general_threshold).nonzero(as_tuple=True)[0]
            character_indices = (
                (char_probs > character_threshold).nonzero(as_tuple=True)[0]
                + self.gen_tag_count
            )
            indices = torch.cat((general_indices, character_indices))
            scores = probs[indices]

        # One bulk transfer to Python lists instead of per-element .item() calls.
        return indices.tolist(), scores.detach().float().tolist()

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Tag the image described by ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) is either a PIL image or a dict with a
                ``"url"`` or a base64 ``"image"`` entry. Optional
                ``data["parameters"]`` keys: ``general_threshold``,
                ``character_threshold``, ``mode`` (``"threshold"`` or
                ``"topk"``; anything other than ``"topk"`` uses thresholds),
                ``include_scores``, ``topk_general``, ``topk_character``.

        Returns:
            Dict with ``feature``, ``character`` and ``ip`` tag lists, plus
            ``_timings`` and ``_params``; ``feature_scores`` and
            ``character_scores`` are added when ``include_scores`` is truthy.

        Raises:
            ValueError: If no image can be resolved from the inputs.
            requests.HTTPError: If fetching the image URL returns an error status.
        """
        inputs = data.pop("inputs", data)

        fetch_start_time = time.time()
        image = self._load_image(inputs)
        if image is None:
            raise ValueError(f"No image or url provided: {data}")
        # Remove the alpha channel if it exists.
        image = pil_to_rgb(image)
        fetch_time = time.time() - fetch_start_time

        # ``"parameters": null`` is treated the same as an absent entry.
        parameters = data.pop("parameters", None) or {}
        general_threshold = float(
            parameters.pop("general_threshold", self.default_general_threshold)
        )
        character_threshold = float(
            parameters.pop("character_threshold", self.default_character_threshold)
        )
        # Optional behavior controls
        mode = parameters.pop("mode", "threshold")  # "threshold" | "topk"
        include_scores = bool(parameters.pop("include_scores", False))
        topk_general = int(parameters.pop("topk_general", 25))
        topk_character = int(parameters.pop("topk_character", 10))

        inference_start_time = time.time()
        with torch.inference_mode():
            # Preprocess image on CPU
            image_tensor = self.transform(image).unsqueeze(0)
            # Pin memory and use non_blocking transfer only when using CUDA
            if self.device == "cuda":
                image_tensor = image_tensor.pin_memory().to(
                    self.device, non_blocking=True
                )
            else:
                image_tensor = image_tensor.to(self.device)

            # Run the model on the selected device; [0] picks the single image.
            probs = self.model(image_tensor)[0]

            combined_indices, combined_scores = self._select_tags(
                probs,
                mode,
                general_threshold,
                character_threshold,
                topk_general,
                topk_character,
            )
        inference_time = time.time() - inference_start_time

        post_process_start_time = time.time()

        cur_gen_tags = []
        cur_char_tags = []
        gen_scores_out: dict[str, float] = {}
        char_scores_out: dict[str, float] = {}

        # Use the pre-computed inverse map for index-to-tag lookups.
        for idx, score in zip(combined_indices, combined_scores):
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                cur_gen_tags.append(tag)
                if include_scores:
                    gen_scores_out[tag] = score
            else:
                cur_char_tags.append(tag)
                if include_scores:
                    char_scores_out[tag] = score

        # Collect the de-duplicated, sorted IPs of all detected characters.
        ip_tags = sorted(
            {
                ip
                for tag in cur_char_tags
                if tag in self.character_ip_mapping
                for ip in self.character_ip_mapping[tag]
            }
        )
        post_process_time = time.time() - post_process_start_time
        total_time = fetch_time + inference_time + post_process_time

        logging.info(
            "Timing - Fetch: %.3fs, Inference: %.3fs, Post-process: %.3fs, Total: %.3fs",
            fetch_time,
            inference_time,
            post_process_time,
            total_time,
        )

        out: dict[str, Any] = {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
            "_timings": {
                "fetch_s": round(fetch_time, 4),
                "inference_s": round(inference_time, 4),
                "post_process_s": round(post_process_time, 4),
                "total_s": round(total_time, 4),
            },
            "_params": {
                "mode": mode,
                "general_threshold": general_threshold,
                "character_threshold": character_threshold,
                "topk_general": topk_general,
                "topk_character": topk_character,
            },
        }

        if include_scores:
            out["feature_scores"] = gen_scores_out
            out["character_scores"] = char_scores_out

        return out