refoundd committed on
Commit
e6021fb
·
verified ·
1 Parent(s): 85fed20

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +214 -49
handler.py CHANGED
@@ -1,39 +1,208 @@
1
  import os
2
  from typing import Any, Dict
3
  from PIL import Image
4
- import torch
5
- from diffusers import FluxPipeline
6
  from huggingface_inference_toolkit.logging import logger
7
- from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
 
 
 
 
 
8
  import time
9
- from para_attn.context_parallel import init_context_parallel_mesh
10
- from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
11
- from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
12
 
13
- class EndpointHandler:
14
- def __init__(self, path=""):
15
- self.pipe = FluxPipeline.from_pretrained(
16
- "NoMoreCopyrightOrg/flux-dev",
17
- torch_dtype=torch.bfloat16,
18
- ).to("cuda")
19
- mesh = init_context_parallel_mesh(
20
- self.pipe.device.type,
21
- max_ring_dim_size=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
- parallelize_pipe(
24
- self.pipe,
25
- mesh=mesh,
 
26
  )
27
- parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
28
- apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
29
- torch._inductor.config.reorder_for_compute_comm_overlap = True
30
- self.pipe.transformer = torch.compile(
31
- self.pipe.transformer, mode="max-autotune-no-cudagraphs",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
- self.pipe.vae = torch.compile(
34
- self.pipe.vae, mode="max-autotune-no-cudagraphs",
 
35
  )
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
37
  def __call__(self, data: Dict[str, Any]) -> str:
38
  logger.info(f"Received incoming request with {data=}")
39
 
@@ -46,27 +215,23 @@ class EndpointHandler:
46
  "Provided input body must contain either the key `inputs` or `prompt` with the"
47
  " prompt to use for the image generation, and it needs to be a non-empty string."
48
  )
49
-
50
- parameters = data.pop("parameters", {})
51
-
52
- num_inference_steps = parameters.get("num_inference_steps", 28)
53
- width = parameters.get("width", 1024)
54
- height = parameters.get("height", 1024)
55
- guidance_scale = parameters.get("guidance_scale", 3.5)
56
-
57
- # seed generator (seed cannot be provided as is but via a generator)
58
- seed = parameters.get("seed", 0)
59
- generator = torch.manual_seed(seed)
60
- start_time = time.time()
61
- result = self.pipe( # type: ignore
62
- prompt,
63
- height=height,
64
- width=width,
65
- guidance_scale=guidance_scale,
66
- num_inference_steps=num_inference_steps,
67
- generator=generator,
68
- ).images[0]
69
- end_time = time.time()
70
- time_taken = end_time - start_time
71
- print(f"Time taken: {time_taken:.2f} seconds")
72
- return result
 
1
  import os
2
  from typing import Any, Dict
3
  from PIL import Image
 
 
4
  from huggingface_inference_toolkit.logging import logger
5
+ from pymongo.mongo_client import MongoClient
6
+ from diffusers.utils import load_image
7
+ import huggingface_hub
8
+ import numpy as np
9
+ import onnxruntime as rt
10
+ import pandas as pd
11
  import time
12
# Hugging Face access token; empty string falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Dataset v3 series of models:
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the emoticon itself and must NOT be
# rewritten to spaces by load_labels().
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a WD-tagger ``selected_tags.csv`` dataframe into names and index lists.

    Args:
        dataframe: pandas DataFrame with ``name`` and ``category`` columns.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        Underscores in tag names are replaced with spaces, except for kaomoji
        tags, which are kept verbatim.
    """
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    # Category codes used by the label file: 9 = rating, 0 = general,
    # 4 = character.
    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
54
+
55
+
56
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Picks the threshold at the midpoint of the largest gap between
    consecutive probabilities sorted in descending order.

    Args:
        probs: 1-D numpy array of prediction probabilities.

    Returns:
        The threshold value as a float; 0.0 when fewer than two
        probabilities are supplied (no gap exists to cut).
    """
    # Guard: with < 2 values the diff array is empty and argmax() raises.
    if probs.size < 2:
        return 0.0
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh
68
+
69
+
70
class Predictor:
    """Lazily-loaded WD-tagger ONNX model with tag post-processing.

    The model and label files are downloaded from the Hugging Face Hub on
    first use and cached in-process; subsequent calls with the same repo id
    skip reloading.
    """

    def __init__(self):
        # Both are set by load_model(); None means nothing is loaded yet.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Download the label CSV and ONNX model; return their local paths."""
        # `token=` replaces the deprecated `use_auth_token=` argument.
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load labels and the ONNX session for *model_repo*, unless cached."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        # Unpack the 4-tuple directly rather than indexing it.
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        # Prefer CUDA, falling back to CPU when no GPU provider is available.
        model = rt.InferenceSession(
            model_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        # Input is NHWC and square, so height alone fixes the target size.
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Convert an RGBA PIL image to the model's (1, H, W, 3) BGR float32 input.

        Flattens alpha onto white, pads to a centered square, then resizes
        to the model's input size.
        """
        target_size = self.model_target_size

        # Flatten transparency onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Tag *image* and return a dict mapping tag name -> confidence.

        Args:
            image: RGBA PIL image to tag.
            model_repo: Hugging Face repo id of the tagger model.
            general_thresh: confidence cut-off for general tags.
            general_mcut_enabled: if True, derive the general threshold via
                MCut instead of using ``general_thresh``.
            character_thresh: confidence cut-off for character tags.
            character_mcut_enabled: if True, derive the character threshold
                via MCut (clamped to at least 0.15).

        Returns:
            Merged dict of rating, character and general tag confidences.
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = dict(x for x in general_names if x[1] > general_thresh)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            # Never let MCut push the character threshold below 0.15.
            character_thresh = max(0.15, mcut_threshold(character_probs))

        character_res = dict(x for x in character_names if x[1] > character_thresh)

        # Removed dead code: a sorted, paren-escaped string of general tags
        # was computed here but never used or returned.
        return {**rating, **character_res, **general_res}
193
class EndpointHandler:
    """Inference endpoint that tags stored request images with WD-tagger keywords."""

    def __init__(self, path=""):
        self.predictor = Predictor()
        self.model_repo = VIT_LARGE_MODEL_DSV3_REPO
        # SECURITY: a MongoDB URI with embedded username/password was
        # previously hard-coded here. Those credentials are compromised by
        # having been committed and MUST be rotated. The URI is now read
        # from the environment; if unset, MongoClient falls back to its
        # default (localhost).
        uri = os.environ.get("MONGODB_URI")
        self.client = MongoClient(uri)

        self.db = self.client['nomorecopyright']
        self.collection = self.db['imagerequests']

        # NOTE(review): query/projection are defined here but __call__
        # currently scans the whole collection without them — confirm
        # whether the timestamp filter and projection should be applied.
        self.query = {"requestTimestamp": {"$gt": "1742815635"}}
        self.projection = {"_id": 0, "requestImage": 1}
206
  def __call__(self, data: Dict[str, Any]) -> str:
207
  logger.info(f"Received incoming request with {data=}")
208
 
 
215
  "Provided input body must contain either the key `inputs` or `prompt` with the"
216
  " prompt to use for the image generation, and it needs to be a non-empty string."
217
  )
218
+ start_index,limit_count=prompt.split(',')
219
+ data = list(self.collection.find().skip(start_index).limit(limit_count))
220
+ start_time=time.time()
221
+ for document in data:
222
+ image=load_image(document.get('requestImage', 'https://nomorecopyright.com/default.jpg'))
223
+ image = image.convert("RGBA")
224
+ outputs = self.predictor.predict(
225
+ image,
226
+ self.model_repo,
227
+ general_thresh=0.35,
228
+ general_mcut_enabled=False,
229
+ character_thresh=0.85,
230
+ character_mcut_enabled=False,
231
+ )
232
+ saveQuery = {"_id": document.get('_id')}
233
+ # Update operation to add keywords with confidence scores
234
+ update_result = self.collection.update_one(saveQuery , {'$set': {'keywords': outputs}})
235
+ end_time=time.time()
236
+ print(f"Time taken: {end_time-start_time:.2f} seconds")
237
+ return 'OK'