refoundd committed on
Commit
2861775
·
verified ·
1 Parent(s): 28c7227

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +145 -193
handler.py CHANGED
@@ -6,218 +6,148 @@ from PIL import Image
6
  from huggingface_inference_toolkit.logging import logger
7
  from pymongo.mongo_client import MongoClient
8
  from diffusers.utils import load_image
9
- import huggingface_hub
10
  import numpy as np
11
- import onnxruntime as rt
12
  import pandas as pd
13
  import time
14
- import subprocess
15
-
16
# Shell steps that fetch and install cuDNN 9.8.0; they run in list order.
commands = [
    "wget https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2004-9.8.0_1.0-1_amd64.deb",
    "sudo dpkg -i cudnn-local-repo-ubuntu2004-9.8.0_1.0-1_amd64.deb",
    "sudo cp /var/cudnn-local-repo-ubuntu2004-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/",
    "sudo apt-get update",
    "sudo apt-get -y install cudnn",
    "sudo apt-get -y install cudnn-cuda-12",
]

# Best-effort install: a failing step is reported and the remaining steps
# are still attempted.
for command in commands:
    print(f"Running command: {command}")
    try:
        subprocess.run(command, shell=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while executing command: {e}")
    else:
        print(f"Command executed successfully: {command}")
34
-
35
# Hugging Face access token; an empty string when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Dataset v3 series of models.
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"

# Files fetched from the model repo.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tag names that are emoticons: their underscores are part of the tag and
# must not be turned into spaces. List taken from
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
66
def load_labels(dataframe) -> tuple[list[str], list, list, list]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: table with a ``name`` column (tag strings) and a
            ``category`` column (integer category codes).

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` has underscores replaced by
        spaces except for names listed in ``kaomojis``; the three index
        lists hold the row positions whose category is 9, 0 and 4
        respectively.
    """
    # Build the lookup once instead of scanning the list for every row.
    keep_verbatim = frozenset(kaomojis)
    tag_names = [
        name if name in keep_verbatim else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    category = dataframe["category"]
    rating_indexes = list(np.where(category == 9)[0])
    general_indexes = list(np.where(category == 0)[0])
    character_indexes = list(np.where(category == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
77
-
78
-
79
def mcut_threshold(probs):
    """Return the Maximum Cut (MCut) threshold for a set of probabilities.

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The probabilities are sorted in descending order and the threshold is the
    midpoint of the two neighbours separated by the largest gap.

    Args:
        probs: one-dimensional array-like of scores.

    Returns:
        The midpoint threshold as a NumPy scalar.

    Raises:
        ValueError: if fewer than two probabilities are given, since no gap
            exists to cut at.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        raise ValueError("mcut_threshold requires at least two probabilities")

    descending = np.sort(probs)[::-1]
    gaps = descending[:-1] - descending[1:]
    cut = gaps.argmax()
    return (descending[cut] + descending[cut + 1]) / 2
91
-
92
-
93
class Predictor:
    """ONNX tagger wrapper that keeps at most one model repo loaded."""

    def __init__(self):
        # Both stay None until load_model() succeeds for the first time.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and the ONNX model file from ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by ``hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            use_auth_token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            use_auth_token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load labels and the inference session for ``model_repo``.

        Does nothing when ``model_repo`` is the repo that is already loaded.
        """
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        model = rt.InferenceSession(
            model_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        # The first input must have four dimensions; the second one is used
        # as the square target size (presumably NHWC -- TODO confirm).
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Turn an RGBA PIL image into a ``(1, size, size, 3)`` float32 array.

        The image is flattened onto white, padded to a square, resized to the
        model target size and reordered from RGB to BGR. ``load_model`` must
        have run first so that ``model_target_size`` is set.
        """
        target_size = self.model_target_size

        # Flatten transparency onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square, centring the original.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize only when needed.
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR.
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    @staticmethod
    def _above_threshold(scored, thresh):
        """Return ``{name: score}`` for the pairs scoring strictly above ``thresh``."""
        return {name: score for name, score in scored if score > thresh}

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Tag ``image`` and return one ``{tag: score}`` mapping.

        Args:
            image: RGBA PIL image.
            model_repo: repo id to load (reused when already loaded).
            general_thresh: minimum score for general tags; ignored when
                ``general_mcut_enabled`` is true.
            general_mcut_enabled: derive the general threshold with MCut.
            character_thresh: minimum score for character tags; ignored when
                ``character_mcut_enabled`` is true.
            character_mcut_enabled: derive the character threshold with MCut,
                floored at 0.15.

        Returns:
            A dict merging all rating scores, then the character tags above
            their threshold, then the general tags above theirs.
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Rating tags are returned with their scores, without thresholding.
        rating = dict(labels[i] for i in self.rating_indexes)

        # General tags: keep any whose confidence exceeds the threshold.
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([score for _, score in general_names])
            general_thresh = mcut_threshold(general_probs)
        general_res = self._above_threshold(general_names, general_thresh)

        # Character tags: same rule, with a 0.15 floor on the MCut threshold.
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([score for _, score in character_names])
            character_thresh = max(0.15, mcut_threshold(character_probs))
        character_res = self._above_threshold(character_names, character_thresh)

        return {**rating, **character_res, **general_res}
216
  class EndpointHandler:
217
  def __init__(self, path=""):
218
- self.predictor = Predictor()
219
- self.model_repo = VIT_LARGE_MODEL_DSV3_REPO
220
- uri = os.environ.get("MongoDB", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  self.client = MongoClient(uri)
222
 
223
  self.db = self.client['nomorecopyright']
@@ -244,18 +174,40 @@ class EndpointHandler:
244
  start_time=time.time()
245
  for document in data:
246
  image=load_image(document.get('createdImage', 'https://nomorecopyright.com/default.jpg'))
247
- image = image.convert("RGBA")
248
- outputs = self.predictor.predict(
249
- image,
250
- self.model_repo,
251
- general_thresh=0.35,
252
- general_mcut_enabled=False,
253
- character_thresh=0.85,
254
- character_mcut_enabled=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  )
 
 
 
 
256
  saveQuery = {"_id": document.get('_id')}
257
  # Update operation to add keywords with confidence scores
258
- update_result = self.collection.update_one(saveQuery , {'$set': {'keywords': outputs}})
259
  end_time=time.time()
260
  print(f"Time taken: {end_time-start_time:.2f} seconds")
261
  return 'OK'
 
6
  from huggingface_inference_toolkit.logging import logger
7
  from pymongo.mongo_client import MongoClient
8
  from diffusers.utils import load_image
 
9
  import numpy as np
 
10
  import pandas as pd
11
  import time
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ from typing import Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ import numpy as np
17
+ import pandas as pd
18
+ import timm
19
+ import torch
20
+ from huggingface_hub import hf_hub_download
21
+ from huggingface_hub.utils import HfHubHTTPError
22
+ from PIL import Image
23
+ from simple_parsing import field
24
+ from timm.data import create_transform, resolve_data_config
25
+ from torch import Tensor, nn
26
+ from torch.nn import functional as F
 
 
 
 
 
 
 
27
 
28
# Hugging Face access token; an empty string when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Run on the GPU when CUDA is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = torch.device(_device_name)

# Short model key -> Hugging Face repo id.
MODEL_REPO_MAP = {"vit": "SmilingWolf/wd-vit-large-tagger-v3"}
33
+
34
+
35
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening alpha onto white.

    Images in other modes (palette, greyscale, ...) are converted to RGBA
    when their ``info`` carries a ``transparency`` entry and to RGB
    otherwise. RGBA images are then composited over a white background.
    """
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Composite over white so transparent areas do not end up black.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
45
 
 
 
 
 
 
46
 
47
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre ``image`` on a white square whose side is its longest edge."""
    width, height = image.size
    side = max(width, height)

    square = Image.new("RGB", (side, side), (255, 255, 255))
    # Integer division: any odd leftover pixel goes to the right/bottom.
    offset = ((side - width) // 2, (side - height) // 2)
    square.paste(image, offset)
    return square
55
 
 
 
 
 
 
 
56
 
57
@dataclass
class LabelData:
    """Tag names together with the row positions of each tag category."""

    # Tag names, in row order.
    names: list[str]
    # Row positions of the rating tags.
    rating: list[np.int64]
    # Row positions of the general tags.
    general: list[np.int64]
    # Row positions of the character tags.
    character: list[np.int64]
63
 
 
 
64
 
65
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hub repo and index its tags.

    Args:
        repo_id: Hugging Face repo to download from.
        revision: optional revision passed through to ``hf_hub_download``.
        token: optional access token passed through to ``hf_hub_download``.

    Returns:
        A ``LabelData`` holding every tag name plus the row positions whose
        ``category`` is 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: when the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    category = df["category"]

    def rows_with(code):
        # Positions of the rows carrying this category code.
        return list(np.where(category == code)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_with(9),
        general=rows_with(0),
        character=rows_with(4),
    )
 
 
 
 
 
 
 
 
 
87
 
 
88
 
89
def get_tags(
    probs: Tensor,
    labels: "LabelData",
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into a caption and per-category score dicts.

    Args:
        probs: one-dimensional tensor of scores, aligned with ``labels.names``.
            It may live on any device and may still be attached to the
            autograd graph.
        labels: tag names plus the positions of rating, general and
            character tags.
        gen_threshold: general tags are kept when their score is strictly
            greater than this value.
        char_threshold: character tags are kept when their score is strictly
            greater than this value.

    Returns:
        ``(caption, taglist, rating_labels, char_labels, gen_labels)``:
        ``caption`` is the kept general tags followed by the kept character
        tags, comma separated; ``taglist`` is the caption with underscores
        replaced by spaces and parentheses backslash-escaped; the three dicts
        map tag name to score as a plain Python ``float``, with the general
        and character dicts ordered by descending score.
    """
    # detach()/cpu() make this safe for tensors that require grad or sit on
    # the GPU (Tensor.numpy() raises for both). tolist() yields built-in
    # floats, which serialise cleanly, unlike numpy.float32 scalars.
    scores = probs.detach().cpu().tolist()
    scored = list(zip(labels.names, scores))

    def select(indexes, threshold):
        # Keep pairs above the threshold, highest score first.
        kept = {name: score for name, score in (scored[i] for i in indexes) if score > threshold}
        ranked = sorted(kept.items(), key=lambda item: item[1], reverse=True)
        return dict(ranked)

    # Rating tags are always reported, without thresholding.
    rating_labels = dict(scored[i] for i in labels.rating)

    gen_labels = select(labels.general, gen_threshold)
    char_labels = select(labels.character, char_threshold)

    # General tags first, then character tags; each group is already sorted.
    combined_names = [*gen_labels, *char_labels]

    # Convert to a string suitable for use as a training caption.
    caption = ", ".join(combined_names)
    taglist = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return caption, taglist, rating_labels, char_labels, gen_labels
 
120
 
 
 
 
 
121
 
122
@dataclass
class ScriptOptions:
    """Tagger settings: input image, model key and score thresholds."""

    # Image to tag; declared as a positional option.
    image_file: Path = field(positional=True)
    # Model key; "vit" by default.
    model: str = field(default="vit")
    # Threshold applied to general tags.
    gen_threshold: float = field(default=0.35)
    # Threshold applied to character tags.
    char_threshold: float = field(default=0.75)
128
 
 
 
 
 
 
 
 
 
 
 
129
  class EndpointHandler:
130
  def __init__(self, path=""):
131
+ self.opts = ScriptOptions
132
+ repo_id = MODEL_REPO_MAP.get(self.opts.model)
133
+
134
+ print(f"Loading model '{self.opts.model}' from '{repo_id}'...")
135
+ self.model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
136
+ state_dict = timm.models.load_state_dict_from_hf(repo_id)
137
+ self.model.load_state_dict(state_dict)
138
+
139
+ print("Loading tag list...")
140
+ self.labels: LabelData = load_labels_hf(repo_id=repo_id)
141
+
142
+ print("Creating data transform...")
143
+ self.transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
144
+
145
+ with torch.inference_mode():
146
+ # move model to GPU, if available
147
+ if torch_device.type != "cpu":
148
+ self.model = self.model.to(torch_device)
149
+
150
+ uri = os.environ.get("MongoDB", "mongodb+srv://jamie:qJiuKQpqhXMHGb74@cluster0.i5ujz.mongodb.net/")
151
  self.client = MongoClient(uri)
152
 
153
  self.db = self.client['nomorecopyright']
 
174
  start_time=time.time()
175
  for document in data:
176
  image=load_image(document.get('createdImage', 'https://nomorecopyright.com/default.jpg'))
177
+ print("Loading image and preprocessing...")
178
+ # get image
179
+ # ensure image is RGB
180
+ img_input = pil_ensure_rgb(image)
181
+ # pad to square with white background
182
+ img_input = pil_pad_square(img_input)
183
+ # run the model's input transform to convert to tensor and rescale
184
+ inputs: Tensor = self.transform(img_input).unsqueeze(0)
185
+ # NCHW image RGB to BGR
186
+ inputs = inputs[:, [2, 1, 0]]
187
+ inputs = inputs.to(torch_device)
188
+ print("Running inference...")
189
+ outputs = self.model.forward(inputs)
190
+ # apply the final activation function (timm doesn't support doing this internally)
191
+ outputs = F.sigmoid(outputs)
192
+ # move inputs, outputs, and model back to to cpu if we were on GPU
193
+ if torch_device.type != "cpu":
194
+ inputs = inputs.to("cpu")
195
+ outputs = outputs.to("cpu")
196
+
197
+ print("Processing results...")
198
+ caption, taglist, ratings, character, general = get_tags(
199
+ probs=outputs.squeeze(0),
200
+ labels=self.labels,
201
+ gen_threshold=self.opts.gen_threshold,
202
+ char_threshold=self.opts.char_threshold,
203
  )
204
+
205
+ results={**ratings, **character, **general}
206
+ print(results)
207
+
208
  saveQuery = {"_id": document.get('_id')}
209
  # Update operation to add keywords with confidence scores
210
+ update_result = self.collection.update_one(saveQuery , {'$set': {'keywords': results}})
211
  end_time=time.time()
212
  print(f"Time taken: {end_time-start_time:.2f} seconds")
213
  return 'OK'