forced lora
Browse files- handler.py +58 -11
handler.py
CHANGED
|
@@ -6,6 +6,7 @@ from io import BytesIO
|
|
| 6 |
from pprint import pprint
|
| 7 |
from typing import Any, Dict, List
|
| 8 |
import os
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Union
|
| 11 |
from concurrent.futures import ThreadPoolExecutor
|
|
@@ -87,6 +88,7 @@ class EndpointHandler:
|
|
| 87 |
self.inference_progress = {} # Dictionary to store progress of each request
|
| 88 |
self.inference_images = {} # Dictionary to store latest image of each request
|
| 89 |
self.total_steps = {}
|
|
|
|
| 90 |
self.inference_in_progress = False
|
| 91 |
|
| 92 |
self.executor = ThreadPoolExecutor(
|
|
@@ -131,6 +133,18 @@ class EndpointHandler:
|
|
| 131 |
self.pipe.enable_attention_slicing()
|
| 132 |
# may need a requirement in the root with xformer
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
def load_lora(self, pipeline, lora_path, lora_weight=0.5):
|
| 135 |
state_dict = load_file(lora_path)
|
| 136 |
LORA_PREFIX_UNET = "lora_unet"
|
|
@@ -218,10 +232,33 @@ class EndpointHandler:
|
|
| 218 |
"""Load Loras models, can lead to marvelous creations"""
|
| 219 |
for model_name, weight in selections:
|
| 220 |
lora_path = EndpointHandler.LORA_PATHS[model_name]
|
| 221 |
-
self.pipe
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
def clean_request_data(self, request_id: str):
|
| 227 |
"""Clean up the data related to a specific request ID."""
|
|
@@ -235,6 +272,9 @@ class EndpointHandler:
|
|
| 235 |
# Remove the request ID from the total_steps dictionary
|
| 236 |
self.total_steps.pop(request_id, None)
|
| 237 |
|
|
|
|
|
|
|
|
|
|
| 238 |
# Set inference to False
|
| 239 |
self.inference_in_progress = False
|
| 240 |
|
|
@@ -349,17 +389,18 @@ class EndpointHandler:
|
|
| 349 |
self.total_steps[request_id] = num_inference_steps
|
| 350 |
|
| 351 |
# Use this to automatically add some negative prompts
|
| 352 |
-
forced_negative = (
|
| 353 |
-
negative_prompt
|
| 354 |
-
+ """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
|
| 355 |
-
)
|
| 356 |
|
| 357 |
# Set the generator seed if provided
|
| 358 |
generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None
|
| 359 |
|
| 360 |
# Load the provided Lora models
|
|
|
|
| 361 |
# if loras_model:
|
| 362 |
-
# self.
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
try:
|
| 365 |
# 2. Process
|
|
@@ -376,8 +417,8 @@ class EndpointHandler:
|
|
| 376 |
callback=lambda step, timestep, latents: self.progress_callback(
|
| 377 |
step, timestep, latents, request_id, "progress"
|
| 378 |
),
|
| 379 |
-
callback_steps=5,
|
| 380 |
-
#
|
| 381 |
).images[0]
|
| 382 |
|
| 383 |
# print(image)
|
|
@@ -405,6 +446,11 @@ class EndpointHandler:
|
|
| 405 |
return {"flag": "error", "message": "Missing request_id."}
|
| 406 |
|
| 407 |
if action == "check_progress":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
return self.check_progress(request_id)
|
| 409 |
|
| 410 |
elif action == "inference":
|
|
@@ -420,6 +466,7 @@ class EndpointHandler:
|
|
| 420 |
self.inference_in_progress = True
|
| 421 |
self.inference_progress[request_id] = 0
|
| 422 |
self.inference_images[request_id] = None
|
|
|
|
| 423 |
|
| 424 |
self.executor.submit(self.start_inference, data)
|
| 425 |
|
|
|
|
| 6 |
from pprint import pprint
|
| 7 |
from typing import Any, Dict, List
|
| 8 |
import os
|
| 9 |
+
import re
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Union
|
| 12 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 88 |
self.inference_progress = {} # Dictionary to store progress of each request
|
| 89 |
self.inference_images = {} # Dictionary to store latest image of each request
|
| 90 |
self.total_steps = {}
|
| 91 |
+
self.active_request_ids = set()
|
| 92 |
self.inference_in_progress = False
|
| 93 |
|
| 94 |
self.executor = ThreadPoolExecutor(
|
|
|
|
| 133 |
self.pipe.enable_attention_slicing()
|
| 134 |
# may need an xformers entry in the root requirements file
|
| 135 |
|
| 136 |
+
# Load loras one time only
|
| 137 |
+
# Must be replaced once we will know how to hot load/unload
|
| 138 |
+
# it use the own made load_lora function
|
| 139 |
+
self.load_selected_loras(
|
| 140 |
+
[
|
| 141 |
+
["polyhedron_new_skin_v1.1", 0.2],
|
| 142 |
+
["detailed_eye-10", 0.2],
|
| 143 |
+
["add_detail", 0.3],
|
| 144 |
+
["MuscleGirl_v1", 0.2],
|
| 145 |
+
]
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
def load_lora(self, pipeline, lora_path, lora_weight=0.5):
|
| 149 |
state_dict = load_file(lora_path)
|
| 150 |
LORA_PREFIX_UNET = "lora_unet"
|
|
|
|
| 232 |
"""Load Loras models, can lead to marvelous creations"""
|
| 233 |
for model_name, weight in selections:
|
| 234 |
lora_path = EndpointHandler.LORA_PATHS[model_name]
|
| 235 |
+
# self.pipe.load_lora_weights(lora_path)
|
| 236 |
+
self.load_lora(self.pipe, lora_path, weight)
|
| 237 |
+
|
def clean_negative_prompt(self, negative_prompt):
    """Normalize *negative_prompt* so every textual-inversion trigger token
    appears exactly once, at the end of the prompt.

    Each token declared in ``self.TEXTUAL_INVERSION`` (a list of dicts with a
    ``"token"`` key) is first stripped from the incoming prompt
    (case-insensitive, whole-word match) so the canonical append below cannot
    duplicate it; the full token list is then appended.

    :param negative_prompt: user-supplied negative prompt string.
    :return: cleaned prompt ending with all trigger tokens.
    """
    tokens = [item["token"] for item in self.TEXTUAL_INVERSION]

    # Remove any pre-existing occurrence of each token (case-insensitive,
    # whole-word) so re-appending below does not stack duplicates.
    for token in tokens:
        negative_prompt = re.sub(
            r"\b" + re.escape(token) + r"\b",
            "",
            negative_prompt,
            flags=re.IGNORECASE,
        ).strip()

    # Collapse the separator debris the removals can leave behind,
    # e.g. "a girl, , blurry" -> "a girl, blurry".
    negative_prompt = re.sub(r"\s*,\s*(?:,\s*)+", ", ", negative_prompt)
    negative_prompt = re.sub(r"\s{2,}", " ", negative_prompt).strip(" ,")

    # Append the canonical token list; strip() handles the empty-prompt case.
    return (negative_prompt + " " + " ".join(tokens)).strip()
| 262 |
|
| 263 |
def clean_request_data(self, request_id: str):
|
| 264 |
"""Clean up the data related to a specific request ID."""
|
|
|
|
| 272 |
# Remove the request ID from the total_steps dictionary
|
| 273 |
self.total_steps.pop(request_id, None)
|
| 274 |
|
| 275 |
+
# Delete request id
|
| 276 |
+
self.active_request_ids.discard(request_id)
|
| 277 |
+
|
| 278 |
# Set inference to False
|
| 279 |
self.inference_in_progress = False
|
| 280 |
|
|
|
|
| 389 |
self.total_steps[request_id] = num_inference_steps
|
| 390 |
|
| 391 |
# Use this to automatically add some negative prompts
|
| 392 |
+
forced_negative = self.clean_negative_prompt(negative_prompt)
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
# Set the generator seed if provided
|
| 395 |
generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None
|
| 396 |
|
| 397 |
# Load the provided Lora models
|
| 398 |
+
# self.pipe.unload_lora_weights() # Unload models to avoid lora staking
|
| 399 |
# if loras_model:
|
| 400 |
+
# self.load_selected_loras(loras_model)
|
| 401 |
+
|
| 402 |
+
# set scale of loras, for now take only first scale of the loaded lora and apply to all until we find the way to apply specified scale
|
| 403 |
+
# scale = {"scale": loras_model[0][1]} if loras_model else None
|
| 404 |
|
| 405 |
try:
|
| 406 |
# 2. Process
|
|
|
|
| 417 |
callback=lambda step, timestep, latents: self.progress_callback(
|
| 418 |
step, timestep, latents, request_id, "progress"
|
| 419 |
),
|
| 420 |
+
callback_steps=5,
|
| 421 |
+
# cross_attention_kwargs={"scale": 0.2},
|
| 422 |
).images[0]
|
| 423 |
|
| 424 |
# print(image)
|
|
|
|
| 446 |
return {"flag": "error", "message": "Missing request_id."}
|
| 447 |
|
| 448 |
if action == "check_progress":
|
| 449 |
+
if request_id not in self.active_request_ids:
|
| 450 |
+
return {
|
| 451 |
+
"flag": "error",
|
| 452 |
+
"message": "Request id doesn't match any active request.",
|
| 453 |
+
}
|
| 454 |
return self.check_progress(request_id)
|
| 455 |
|
| 456 |
elif action == "inference":
|
|
|
|
| 466 |
self.inference_in_progress = True
|
| 467 |
self.inference_progress[request_id] = 0
|
| 468 |
self.inference_images[request_id] = None
|
| 469 |
+
self.active_request_ids.add(request_id)
|
| 470 |
|
| 471 |
self.executor.submit(self.start_inference, data)
|
| 472 |
|