error CUDA out of memory. Tried to allocate 146.00 MiB.

#74
by annisamau - opened
MagicQuill/brushnet/powerpaint_utils.py CHANGED
@@ -51,27 +51,7 @@ class TokenizerWrapper:
51
  Args:
52
  tokens (Union[str, List[str]]): The tokens to be added.
53
  """
54
- # Check if tokens exist first to avoid assertion error in wrapped tokenizer
55
- # and to ensure idempotency in shared environments.
56
- if isinstance(tokens, str):
57
- tokens_to_check = [tokens]
58
- else:
59
- tokens_to_check = tokens
60
-
61
- # If all tokens are already in the vocabulary, skip adding them.
62
- # This relies on the wrapped tokenizer's behavior or checking its vocab.
63
- # Usually `add_tokens` returns 0 if all tokens exist.
64
- # We just want to avoid the assertion error if they are already added.
65
-
66
  num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
67
- if num_added_tokens == 0:
68
- # Check if they actually exist (idempotency case)
69
- # If they exist, we don't assert error, just return.
70
- # If they don't exist but add_tokens returned 0 (shouldn't happen for new tokens),
71
- # then we might have an issue.
72
- # For simplicity in fixing the leak/crash: if 0 added, assume they exist.
73
- return
74
-
75
  assert num_added_tokens != 0, (
76
  f"The tokenizer already contains the token {tokens}. Please pass "
77
  "a different `placeholder_token` that is not already in the "
@@ -102,11 +82,6 @@ class TokenizerWrapper:
102
  the added placeholder token.
103
  *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
104
  """
105
-
106
- # Check if already in token_map (idempotency)
107
- if placeholder_token in self.token_map:
108
- return
109
-
110
  output = []
111
  if num_vec_per_token == 1:
112
  self.try_adding_tokens(placeholder_token, *args, **kwargs)
@@ -301,29 +276,55 @@ class EmbeddingLayerWithFixes(nn.Module):
301
 
302
  def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
303
  """Add external embeddings to this layer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  """
305
  if isinstance(embeddings, dict):
306
  embeddings = [embeddings]
307
 
308
- # Idempotency check: filter out embeddings that are already present by name
309
- existing_names = {emb["name"] for emb in self.external_embeddings}
310
- new_embeddings = []
311
- for emb in embeddings:
312
- if emb["name"] not in existing_names:
313
- new_embeddings.append(emb)
314
- # Optional: Warn or check if existing embedding matches the new one?
315
- # For now, assume if name exists, it's the same token being re-added.
316
-
317
- if not new_embeddings:
318
- return
319
-
320
- self.external_embeddings += new_embeddings
321
  self.check_duplicate_names(self.external_embeddings)
322
  self.check_ids_overlap(self.external_embeddings)
323
 
324
  # set for trainable
325
  added_trainable_emb_info = []
326
- for embedding in new_embeddings:
327
  trainable = embedding.get("trainable", False)
328
  if trainable:
329
  name = embedding["name"]
@@ -331,7 +332,7 @@ class EmbeddingLayerWithFixes(nn.Module):
331
  self.trainable_embeddings[name] = embedding["embedding"]
332
  added_trainable_emb_info.append(name)
333
 
334
- added_emb_info = [emb["name"] for emb in new_embeddings]
335
  added_emb_info = ", ".join(added_emb_info)
336
  print(f"Successfully add external embeddings: {added_emb_info}.", "current")
337
 
@@ -459,39 +460,17 @@ def add_tokens(
459
  assert len(initialize_tokens) == len(
460
  placeholder_tokens
461
  ), "placeholder_token should be the same length as initialize_token"
462
-
463
- # Safe to call multiple times (idempotent)
464
  for ii in range(len(placeholder_tokens)):
465
  tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
466
 
467
  # text_encoder.set_embedding_layer()
468
  embedding_layer = text_encoder.text_model.embeddings.token_embedding
469
- if not isinstance(embedding_layer, EmbeddingLayerWithFixes):
470
- text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
471
- embedding_layer = text_encoder.text_model.embeddings.token_embedding
472
 
473
  assert embedding_layer is not None, (
474
  "Do not support get embedding layer for current text encoder. " "Please check your configuration."
475
  )
476
-
477
- # Only calculate initialization for tokens that are NOT already in the layer
478
- existing_names = {emb["name"] for emb in embedding_layer.external_embeddings}
479
- tokens_to_add = []
480
- init_tokens_to_add = []
481
-
482
- for ii, token in enumerate(placeholder_tokens):
483
- # This check assumes the placeholder token name matches the embedding name
484
- # TokenizerWrapper adds suffix _0, _1 etc if num_vec > 1.
485
- # The logic below handles generic case, but here we assume 1-to-1 or we check the main token.
486
- # Actually EmbeddingLayer uses specific names. TokenizerWrapper.add_placeholder_token generates them.
487
- # If num_vec_per_token > 1, TokenizerWrapper generates token_0, token_1...
488
- # Let's check if the embedding layer already has them.
489
-
490
- # The original code below generated embeddings for ALL input tokens.
491
- # add_embeddings will filter them out.
492
- # But we need to be careful not to re-initialize them differently if they exist.
493
- pass
494
-
495
  initialize_embedding = []
496
  if initialize_tokens is not None:
497
  for ii in range(len(placeholder_tokens)):
@@ -510,12 +489,8 @@ def add_tokens(
510
 
511
  token_info_all = []
512
  for ii in range(len(placeholder_tokens)):
513
- # get_token_info relies on the token being in tokenizer.
514
- # add_placeholder_token ensures it's there (idempotent now).
515
  token_info = tokenizer.get_token_info(placeholder_tokens[ii])
516
  token_info["embedding"] = initialize_embedding[ii]
517
  token_info["trainable"] = True
518
  token_info_all.append(token_info)
519
-
520
- # Idempotency is handled inside add_embeddings now
521
  embedding_layer.add_embeddings(token_info_all)
 
51
  Args:
52
  tokens (Union[str, List[str]]): The tokens to be added.
53
  """
 
 
 
 
 
 
 
 
 
 
 
 
54
  num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
 
 
 
 
 
 
 
 
55
  assert num_added_tokens != 0, (
56
  f"The tokenizer already contains the token {tokens}. Please pass "
57
  "a different `placeholder_token` that is not already in the "
 
82
  the added placeholder token.
83
  *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
84
  """
 
 
 
 
 
85
  output = []
86
  if num_vec_per_token == 1:
87
  self.try_adding_tokens(placeholder_token, *args, **kwargs)
 
276
 
277
  def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
278
  """Add external embeddings to this layer.
279
+
280
+ Use case:
281
+
282
+ >>> 1. Add token to tokenizer and get the token id.
283
+ >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
284
+ >>> # 'how much' in kiswahili
285
+ >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
286
+ >>>
287
+ >>> 2. Add external embeddings to the model.
288
+ >>> new_embedding = {
289
+ >>> 'name': 'ngapi', # 'how much' in kiswahili
290
+ >>> 'embedding': torch.ones(1, 15) * 4,
291
+ >>> 'start': tokenizer.get_token_info('kwaheri')['start'],
292
+ >>> 'end': tokenizer.get_token_info('kwaheri')['end'],
293
+ >>> 'trainable': False # if True, will registry as a parameter
294
+ >>> }
295
+ >>> embedding_layer = nn.Embedding(10, 15)
296
+ >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
297
+ >>> embedding_layer_wrapper.add_embeddings(new_embedding)
298
+ >>>
299
+ >>> 3. Forward tokenizer and embedding layer!
300
+ >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
301
+ >>> input_ids = tokenizer(
302
+ >>> input_text, padding='max_length', truncation=True,
303
+ >>> return_tensors='pt')['input_ids']
304
+ >>> out_feat = embedding_layer_wrapper(input_ids)
305
+ >>>
306
+ >>> 4. Let's validate the result!
307
+ >>> assert (out_feat[0, 3: 7] == 2.3).all()
308
+ >>> assert (out_feat[2, 5: 9] == 2.3).all()
309
+
310
+ Args:
311
+ embeddings (Union[dict, list[dict]]): The external embeddings to
312
+ be added. Each dict must contain the following 4 fields: 'name'
313
+ (the name of this embedding), 'embedding' (the embedding
314
+ tensor), 'start' (the start token id of this embedding), 'end'
315
+ (the end token id of this embedding). For example:
316
+ `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
317
  """
318
  if isinstance(embeddings, dict):
319
  embeddings = [embeddings]
320
 
321
+ self.external_embeddings += embeddings
 
 
 
 
 
 
 
 
 
 
 
 
322
  self.check_duplicate_names(self.external_embeddings)
323
  self.check_ids_overlap(self.external_embeddings)
324
 
325
  # set for trainable
326
  added_trainable_emb_info = []
327
+ for embedding in embeddings:
328
  trainable = embedding.get("trainable", False)
329
  if trainable:
330
  name = embedding["name"]
 
332
  self.trainable_embeddings[name] = embedding["embedding"]
333
  added_trainable_emb_info.append(name)
334
 
335
+ added_emb_info = [emb["name"] for emb in embeddings]
336
  added_emb_info = ", ".join(added_emb_info)
337
  print(f"Successfully add external embeddings: {added_emb_info}.", "current")
338
 
 
460
  assert len(initialize_tokens) == len(
461
  placeholder_tokens
462
  ), "placeholder_token should be the same length as initialize_token"
 
 
463
  for ii in range(len(placeholder_tokens)):
464
  tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
465
 
466
  # text_encoder.set_embedding_layer()
467
  embedding_layer = text_encoder.text_model.embeddings.token_embedding
468
+ text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
469
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
 
470
 
471
  assert embedding_layer is not None, (
472
  "Do not support get embedding layer for current text encoder. " "Please check your configuration."
473
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  initialize_embedding = []
475
  if initialize_tokens is not None:
476
  for ii in range(len(placeholder_tokens)):
 
489
 
490
  token_info_all = []
491
  for ii in range(len(placeholder_tokens)):
 
 
492
  token_info = tokenizer.get_token_info(placeholder_tokens[ii])
493
  token_info["embedding"] = initialize_embedding[ii]
494
  token_info["trainable"] = True
495
  token_info_all.append(token_info)
 
 
496
  embedding_layer.add_embeddings(token_info_all)
MagicQuill/brushnet_nodes.py CHANGED
@@ -149,13 +149,7 @@ class PowerPaintCLIPLoader:
149
 
150
  print('PowerPaint base CLIP file: ', base_CLIP_file)
151
 
152
- # Reuse TokenizerWrapper if already wrapped
153
- if isinstance(pp_clip.tokenizer.clip_l.tokenizer, TokenizerWrapper):
154
- pp_tokenizer = pp_clip.tokenizer.clip_l.tokenizer
155
- else:
156
- pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
157
- pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
158
-
159
  pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
160
 
161
  add_tokens(
@@ -170,8 +164,7 @@ class PowerPaintCLIPLoader:
170
 
171
  print('PowerPaint CLIP file: ', pp_CLIP_file)
172
 
173
- # Already assigned above if new, or reused if existing
174
- # pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
175
  pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
176
 
177
  return (pp_clip,)
@@ -287,7 +280,7 @@ class PowerPaint:
287
  # unload vae and CLIPs
288
  del vae
289
  del clip
290
- for loaded_model in list(comfy.model_management.current_loaded_models):
291
  if type(loaded_model.model.model) in ModelsToUnload:
292
  comfy.model_management.current_loaded_models.remove(loaded_model)
293
  loaded_model.model_unload()
@@ -375,7 +368,7 @@ class BrushNet:
375
 
376
  # unload vae
377
  del vae
378
- for loaded_model in list(comfy.model_management.current_loaded_models):
379
  if type(loaded_model.model.model) in ModelsToUnload:
380
  comfy.model_management.current_loaded_models.remove(loaded_model)
381
  loaded_model.model_unload()
@@ -948,20 +941,6 @@ def brushnet_inference(x, timesteps, transformer_options, debug):
948
  debug=debug,
949
  )
950
 
951
- def forward_patched_by_brushnet(self, x, *args, **kwargs):
952
- h = self.original_forward(x, *args, **kwargs)
953
- if hasattr(self, 'add_sample_after') and type(self):
954
- to_add = self.add_sample_after
955
- if torch.is_tensor(to_add):
956
- # interpolate due to RAUNet
957
- if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
958
- to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
959
- h += to_add.to(h.dtype).to(h.device)
960
- else:
961
- h += self.add_sample_after
962
- self.add_sample_after = 0
963
- return h
964
-
965
 
966
  # This is main patch function
967
  def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
@@ -973,7 +952,7 @@ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
973
  is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
974
 
975
  if is_SDXL:
976
- input_blocks = [[0, comfy.ops.manual_cast.Conv2d],
977
  [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
978
  [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
979
  [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
@@ -995,7 +974,7 @@ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
995
  [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
996
  [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
997
  else:
998
- input_blocks = [[0, comfy.ops.manual_cast.Conv2d],
999
  [1, comfy.ldm.modules.attention.SpatialTransformer],
1000
  [2, comfy.ldm.modules.attention.SpatialTransformer],
1001
  [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
@@ -1080,33 +1059,36 @@ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
1080
  bo['latent_id'] = 0
1081
 
1082
  # patch layers `forward` so we can apply brushnet
1083
- # Modified to prevent re-patching leak and closure creation
 
 
 
 
 
 
 
 
 
 
 
 
 
1084
  for i, block in enumerate(model.model.diffusion_model.input_blocks):
1085
  for j, layer in enumerate(block):
1086
  if not hasattr(layer, 'original_forward'):
1087
  layer.original_forward = layer.forward
1088
-
1089
- # Only patch if not already patched by us
1090
- if getattr(layer.forward, '__func__', None) != forward_patched_by_brushnet:
1091
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1092
-
1093
  layer.add_sample_after = 0
1094
 
1095
  for j, layer in enumerate(model.model.diffusion_model.middle_block):
1096
  if not hasattr(layer, 'original_forward'):
1097
  layer.original_forward = layer.forward
1098
-
1099
- if getattr(layer.forward, '__func__', None) != forward_patched_by_brushnet:
1100
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1101
-
1102
  layer.add_sample_after = 0
1103
 
1104
  for i, block in enumerate(model.model.diffusion_model.output_blocks):
1105
  for j, layer in enumerate(block):
1106
  if not hasattr(layer, 'original_forward'):
1107
  layer.original_forward = layer.forward
1108
-
1109
- if getattr(layer.forward, '__func__', None) != forward_patched_by_brushnet:
1110
- layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1111
-
1112
- layer.add_sample_after = 0
 
149
 
150
  print('PowerPaint base CLIP file: ', base_CLIP_file)
151
 
152
+ pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
 
 
 
 
 
 
153
  pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
154
 
155
  add_tokens(
 
164
 
165
  print('PowerPaint CLIP file: ', pp_CLIP_file)
166
 
167
+ pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
 
168
  pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
169
 
170
  return (pp_clip,)
 
280
  # unload vae and CLIPs
281
  del vae
282
  del clip
283
+ for loaded_model in comfy.model_management.current_loaded_models:
284
  if type(loaded_model.model.model) in ModelsToUnload:
285
  comfy.model_management.current_loaded_models.remove(loaded_model)
286
  loaded_model.model_unload()
 
368
 
369
  # unload vae
370
  del vae
371
+ for loaded_model in comfy.model_management.current_loaded_models:
372
  if type(loaded_model.model.model) in ModelsToUnload:
373
  comfy.model_management.current_loaded_models.remove(loaded_model)
374
  loaded_model.model_unload()
 
941
  debug=debug,
942
  )
943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
 
945
  # This is main patch function
946
  def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
 
952
  is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
953
 
954
  if is_SDXL:
955
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
956
  [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
957
  [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
958
  [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
 
974
  [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
975
  [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
976
  else:
977
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
978
  [1, comfy.ldm.modules.attention.SpatialTransformer],
979
  [2, comfy.ldm.modules.attention.SpatialTransformer],
980
  [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
 
1059
  bo['latent_id'] = 0
1060
 
1061
  # patch layers `forward` so we can apply brushnet
1062
+ def forward_patched_by_brushnet(self, x, *args, **kwargs):
1063
+ h = self.original_forward(x, *args, **kwargs)
1064
+ if hasattr(self, 'add_sample_after') and type(self):
1065
+ to_add = self.add_sample_after
1066
+ if torch.is_tensor(to_add):
1067
+ # interpolate due to RAUNet
1068
+ if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
1069
+ to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
1070
+ h += to_add.to(h.dtype).to(h.device)
1071
+ else:
1072
+ h += self.add_sample_after
1073
+ self.add_sample_after = 0
1074
+ return h
1075
+
1076
  for i, block in enumerate(model.model.diffusion_model.input_blocks):
1077
  for j, layer in enumerate(block):
1078
  if not hasattr(layer, 'original_forward'):
1079
  layer.original_forward = layer.forward
1080
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
 
 
 
 
1081
  layer.add_sample_after = 0
1082
 
1083
  for j, layer in enumerate(model.model.diffusion_model.middle_block):
1084
  if not hasattr(layer, 'original_forward'):
1085
  layer.original_forward = layer.forward
1086
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
 
 
 
1087
  layer.add_sample_after = 0
1088
 
1089
  for i, block in enumerate(model.model.diffusion_model.output_blocks):
1090
  for j, layer in enumerate(block):
1091
  if not hasattr(layer, 'original_forward'):
1092
  layer.original_forward = layer.forward
1093
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1094
+ layer.add_sample_after = 0
 
 
 
MagicQuill/comfy/cli_args.py CHANGED
@@ -59,7 +59,7 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
59
  fpunet_group = parser.add_mutually_exclusive_group()
60
  fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
61
  fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
62
- fpunet_group.add_argument("--fp8_e4m3fn-unet", type=bool, default=True, help="Store unet weights in fp8_e4m3fn.")
63
  fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
64
 
65
  fpvae_group = parser.add_mutually_exclusive_group()
 
59
  fpunet_group = parser.add_mutually_exclusive_group()
60
  fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
61
  fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
62
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
63
  fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
64
 
65
  fpvae_group = parser.add_mutually_exclusive_group()
MagicQuill/scribble_color_edit.py CHANGED
@@ -41,8 +41,7 @@ class ScribbleColorEditModel():
41
  self.brushnet_loader.inpaint_files = get_files_with_extension('inpaint')
42
  print("self.brushnet_loader.inpaint_files: ", get_files_with_extension('inpaint'))
43
  self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]
44
-
45
- @torch.inference_mode()
46
  def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt, mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler, base_model_version='SD1.5', dtype='float16', palette_resolution=2048):
47
  if ckpt_name != self.ckpt_name:
48
  self.ckpt_name = ckpt_name
 
41
  self.brushnet_loader.inpaint_files = get_files_with_extension('inpaint')
42
  print("self.brushnet_loader.inpaint_files: ", get_files_with_extension('inpaint'))
43
  self.brushnet = self.brushnet_loader.brushnet_loading(brushnet_name, dtype)[0]
44
+
 
45
  def process(self, ckpt_name, image, colored_image, positive_prompt, negative_prompt, mask, add_mask, remove_mask, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler, base_model_version='SD1.5', dtype='float16', palette_resolution=2048):
46
  if ckpt_name != self.ckpt_name:
47
  self.ckpt_name = ckpt_name
app.py CHANGED
@@ -5,7 +5,6 @@ subprocess.run(
5
  "pip install ./gradio_magicquill-0.0.1-py3-none-any.whl"
6
  )
7
  )
8
- import spaces
9
  import gradio as gr
10
  from gradio_magicquill import MagicQuill
11
  import random
@@ -14,7 +13,8 @@ import numpy as np
14
  from PIL import Image, ImageOps
15
  import base64
16
  import io
17
- from fastapi import Request
 
18
  from MagicQuill import folder_paths
19
  from MagicQuill.scribble_color_edit import ScribbleColorEditModel
20
  from gradio_client import Client, handle_file
@@ -23,18 +23,11 @@ import tempfile
23
  import cv2
24
  import os
25
  import requests
26
- import gc
27
 
28
  snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
29
  # HF_TOKEN = os.environ.get("HF_TOKEN")
30
  # The client has been made public. Welcome to duplicate our repo.
31
- _HELPER_SPACE = "LiuZichen/MagicQuillHelper"
32
- _client = None
33
- def client():
34
- global _client
35
- if _client is None:
36
- _client = Client(_HELPER_SPACE)
37
- return _client
38
  scribbleColorEditModel = ScribbleColorEditModel()
39
 
40
  def tensor_to_numpy(tensor):
@@ -136,11 +129,10 @@ def guess_prompt_handler(original_image, add_color_image, add_edge_image):
136
  add_color_image_file.close()
137
  add_edge_mask_file.close()
138
 
139
- res = client().predict(
140
  handle_file(original_image_file.name),
141
  handle_file(add_color_image_file.name),
142
- handle_file(add_edge_mask_file.name),
143
- api_name="/guess_prompt"
144
  )
145
 
146
  if original_image_file and os.path.exists(original_image_file.name):
@@ -185,20 +177,9 @@ def generate(ckpt_name, total_mask, original_image, add_color_image, add_edge_im
185
  )
186
 
187
  final_image_base64 = tensor_to_base64(final_image)
188
-
189
- del latent_samples, final_image, lineart_output, color_output
190
- gc.collect()
191
- torch.cuda.empty_cache()
192
-
193
  return final_image_base64
194
 
195
-
196
- from MagicQuill.comfy import model_management as _mm
197
-
198
-
199
- @spaces.GPU(duration=20)
200
  def generate_image_handler(x, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
201
- _mm.unload_all_models()
202
  if seed == -1:
203
  seed = random.randint(0, 2**32 - 1)
204
  ms_data = x['from_frontend']
@@ -225,7 +206,6 @@ with gr.Blocks(css=css, head=head) as demo:
225
  """
226
  # Welcome to MagicQuill! The paper has been accepted to CVPR 2025.
227
  Click the [link](https://magicquill.art) to view our demo and tutorial. Give us a [GitHub star](https://github.com/magic-quill/magicquill) if you are interested.
228
- MagicQuillV2 is available!!! Check our [demo](https://magicquill.art/v2/).
229
  """)
230
  with gr.Row(elem_classes="row"):
231
  ms = MagicQuill()
@@ -234,12 +214,11 @@ with gr.Blocks(css=css, head=head) as demo:
234
  btn = gr.Button("Run", variant="primary")
235
  with gr.Column():
236
  with gr.Accordion("parameters", open=False):
237
- ckpt_value = os.path.join('SD1.5', 'realisticVisionV60B1_v51VAE.safetensors')
238
  ckpt_name = gr.Dropdown(
239
- label="Base Model (fixed for demo)",
240
- choices=[ckpt_value],
241
- value=ckpt_value,
242
- interactive=False
243
  )
244
  negative_prompt = gr.Textbox(
245
  label="Negative Prompt",
@@ -333,19 +312,23 @@ with gr.Blocks(css=css, head=head) as demo:
333
  """)
334
  demo.queue(max_size=20, status_update_rate=0.1)
335
 
 
336
 
337
- if __name__ == "__main__":
338
- demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
 
 
 
339
 
340
- @demo.app.post("/magic_quill/guess_prompt")
341
- async def guess_prompt(request: Request):
342
- data = await request.json()
343
- return guess_prompt_handler(data['original_image'], data['add_color_image'], data['add_edge_image'])
 
 
344
 
345
- @demo.app.post("/magic_quill/process_background_img")
346
- async def process_background_img(request: Request):
347
- img = await request.json()
348
- resized_img_tensor = load_and_resize_image(img)
349
- return "data:image/png;base64," + tensor_to_base64(resized_img_tensor)
350
 
351
- demo.block_thread()
 
 
 
5
  "pip install ./gradio_magicquill-0.0.1-py3-none-any.whl"
6
  )
7
  )
 
8
  import gradio as gr
9
  from gradio_magicquill import MagicQuill
10
  import random
 
13
  from PIL import Image, ImageOps
14
  import base64
15
  import io
16
+ from fastapi import FastAPI, Request
17
+ import uvicorn
18
  from MagicQuill import folder_paths
19
  from MagicQuill.scribble_color_edit import ScribbleColorEditModel
20
  from gradio_client import Client, handle_file
 
23
  import cv2
24
  import os
25
  import requests
 
26
 
27
  snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
28
  # HF_TOKEN = os.environ.get("HF_TOKEN")
29
  # The client has been made public. Welcome to duplicate our repo.
30
+ client = Client("LiuZichen/DrawNGuess")
 
 
 
 
 
 
31
  scribbleColorEditModel = ScribbleColorEditModel()
32
 
33
  def tensor_to_numpy(tensor):
 
129
  add_color_image_file.close()
130
  add_edge_mask_file.close()
131
 
132
+ res = client.predict(
133
  handle_file(original_image_file.name),
134
  handle_file(add_color_image_file.name),
135
+ handle_file(add_edge_mask_file.name)
 
136
  )
137
 
138
  if original_image_file and os.path.exists(original_image_file.name):
 
177
  )
178
 
179
  final_image_base64 = tensor_to_base64(final_image)
 
 
 
 
 
180
  return final_image_base64
181
 
 
 
 
 
 
182
  def generate_image_handler(x, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
 
183
  if seed == -1:
184
  seed = random.randint(0, 2**32 - 1)
185
  ms_data = x['from_frontend']
 
206
  """
207
  # Welcome to MagicQuill! The paper has been accepted to CVPR 2025.
208
  Click the [link](https://magicquill.art) to view our demo and tutorial. Give us a [GitHub star](https://github.com/magic-quill/magicquill) if you are interested.
 
209
  """)
210
  with gr.Row(elem_classes="row"):
211
  ms = MagicQuill()
 
214
  btn = gr.Button("Run", variant="primary")
215
  with gr.Column():
216
  with gr.Accordion("parameters", open=False):
 
217
  ckpt_name = gr.Dropdown(
218
+ label="Base Model Name",
219
+ choices=folder_paths.get_filename_list("checkpoints"),
220
+ value='SD1.5/realisticVisionV60B1_v51VAE.safetensors',
221
+ interactive=True
222
  )
223
  negative_prompt = gr.Textbox(
224
  label="Negative Prompt",
 
312
  """)
313
  demo.queue(max_size=20, status_update_rate=0.1)
314
 
315
+ app = FastAPI()
316
 
317
+ @app.post("/magic_quill/guess_prompt")
318
+ async def guess_prompt(request: Request):
319
+ data = await request.json()
320
+ res = guess_prompt_handler(data['original_image'], data['add_color_image'], data['add_edge_image'])
321
+ return res
322
 
323
+ @app.post("/magic_quill/process_background_img")
324
+ async def process_background_img(request: Request):
325
+ img = await request.json()
326
+ resized_img_tensor = load_and_resize_image(img)
327
+ resized_img_base64 = "data:image/png;base64," + tensor_to_base64(resized_img_tensor)
328
+ return resized_img_base64
329
 
330
+ app = gr.mount_gradio_app(app, demo, "/")
 
 
 
 
331
 
332
+ if __name__ == "__main__":
333
+ uvicorn.run(app, host="0.0.0.0", port=7860)
334
+ # demo.launch()
requirements.txt CHANGED
@@ -14,7 +14,7 @@ anyio==4.4.0
14
  async-timeout==4.0.3
15
  attrs==23.2.0
16
  beautifulsoup4==4.12.3
17
- bitsandbytes
18
  certifi==2024.7.4
19
  cffi==1.16.0
20
  chardet==5.2.0
@@ -33,7 +33,7 @@ einops-exts==0.0.4
33
  embreex==2.17.7.post5
34
  eval-type-backport==0.2.0
35
  exceptiongroup==1.2.2
36
- fastapi<0.112
37
  ffmpy==0.4.0
38
  filelock==3.15.4
39
  flatbuffers==24.3.25
@@ -132,11 +132,11 @@ sounddevice==0.4.7
132
  soupsieve==2.5
133
  spandrel==0.3.4
134
  stanza==1.1.1
135
- starlette<0.38
136
  svg-path==6.3
137
  svglib==1.5.1
138
  svgwrite==1.4.3
139
- sympy==1.13.3
140
  tabulate==0.9.0
141
  termcolor==2.4.0
142
  threadpoolctl==3.5.0
@@ -151,8 +151,7 @@ tqdm==4.66.5
151
  trampoline==0.1.2
152
  transformers==4.37.2
153
  trimesh==4.4.3
154
- torch==2.8.0
155
- torchvision==0.23.0
156
  torchsde==0.2.6
157
  typer==0.12.5
158
  typing-extensions==4.12.2
 
14
  async-timeout==4.0.3
15
  attrs==23.2.0
16
  beautifulsoup4==4.12.3
17
+ bitsandbytes==0.43.3
18
  certifi==2024.7.4
19
  cffi==1.16.0
20
  chardet==5.2.0
 
33
  embreex==2.17.7.post5
34
  eval-type-backport==0.2.0
35
  exceptiongroup==1.2.2
36
+ fastapi
37
  ffmpy==0.4.0
38
  filelock==3.15.4
39
  flatbuffers==24.3.25
 
132
  soupsieve==2.5
133
  spandrel==0.3.4
134
  stanza==1.1.1
135
+ starlette
136
  svg-path==6.3
137
  svglib==1.5.1
138
  svgwrite==1.4.3
139
+ sympy==1.13.1
140
  tabulate==0.9.0
141
  termcolor==2.4.0
142
  threadpoolctl==3.5.0
 
151
  trampoline==0.1.2
152
  transformers==4.37.2
153
  trimesh==4.4.3
154
+ triton==2.1.0
 
155
  torchsde==0.2.6
156
  typer==0.12.5
157
  typing-extensions==4.12.2