primerz committed on
Commit
3e3e641
·
verified ·
1 Parent(s): 6d5987b

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +30 -24
generator.py CHANGED
@@ -1,18 +1,23 @@
1
  import torch
2
  from config import Config
3
- from utils import resize_image_to_1mp, get_caption, prepare_control_images
4
 
5
  class Generator:
6
  def __init__(self, model_handler):
7
  self.mh = model_handler
8
 
 
 
 
 
 
 
9
  def predict(self, input_image, user_prompt=""):
10
  # 1. Pre-process Inputs
11
  print("Processing Input...")
12
  processed_image = resize_image_to_1mp(input_image)
13
 
14
  # 2. Get Face Embedding (Robust Mode)
15
- # Now returns None instead of crashing if no face is found
16
  face_emb = self.mh.get_face_embedding(processed_image)
17
 
18
  # 3. Generate Prompt
@@ -25,49 +30,50 @@ class Generator:
25
  print(f"Prompt: {final_prompt}")
26
 
27
  # 4. Generate Control Maps (Structure)
28
- depth_map, lineart_map = prepare_control_images(processed_image, self.mh.zoe_detector, self.mh.lineart_detector)
 
29
 
30
  # 5. Logic for Face vs No-Face
31
  if face_emb is not None:
32
  print("Face detected: Applying InstantID.")
33
- # [InstantID, Zoe, LineArt]
34
- # Stop InstantID at 50% to allow pixelation
 
35
  controlnet_conditioning_scale = [0.6, 0.4, 0.4]
 
 
36
  control_guidance_end = [0.5, 0.8, 0.8]
 
 
37
  ip_adapter_scale = 0.9
38
 
39
- # InstantID requires the face embedding usually via IP-adapter input
40
- # We pass the processed image to ip_adapter_image (library handles crop internally usually,
41
- # or we rely on the embedding we extracted if using custom pipeline.
42
- # Standard diffusers IP adapter uses the image).
43
- ip_image = processed_image
44
  else:
45
  print("No face detected: Disabling InstantID, using only Structure+Style.")
46
  # Disable InstantID (Weight 0.0)
47
  controlnet_conditioning_scale = [0.0, 0.4, 0.4]
48
  control_guidance_end = [0.0, 0.8, 0.8]
49
  ip_adapter_scale = 0.0
50
- # Pass generic image to satisfy input requirement, but scale is 0 so it's ignored
51
- ip_image = processed_image
52
-
53
- # Set IP Adapter Scale
54
- self.mh.pipeline.set_ip_adapter_scale(ip_adapter_scale)
55
 
56
  # 6. Run Inference
 
57
  result = self.mh.pipeline(
58
  prompt=final_prompt,
59
- # We pass the image list corresponding to [InstantID, Zoe, LineArt]
60
- # Even if InstantID weight is 0, we must pass an image to keep list length correct.
61
- image=[processed_image, depth_map, lineart_map],
62
-
63
- # IP Adapter input
64
- ip_adapter_image=[ip_image],
65
 
 
 
66
  controlnet_conditioning_scale=controlnet_conditioning_scale,
67
  control_guidance_end=control_guidance_end,
68
- num_inference_steps=8, # LCM is fast
69
- guidance_scale=1.5, # LCM needs low CFG
70
- cross_attention_kwargs={"scale": 1.0}
 
 
71
  ).images[0]
72
 
73
  return result
 
1
  import torch
2
  from config import Config
3
+ from utils import resize_image_to_1mp, get_caption
4
 
5
  class Generator:
6
  def __init__(self, model_handler):
7
  self.mh = model_handler
8
 
9
def prepare_control_images(self, image):
    """Build the two conditioning maps for ``image``.

    Calls the model handler's ``zoe_detector`` first and its
    ``lineart_detector`` second, both on the same input image.

    Returns:
        A ``(depth_map, lineart_map)`` tuple, in that order.
    """
    handler = self.mh
    # Keep the original call order: Zoe detector first, then line-art.
    depth = handler.zoe_detector(image)
    lineart = handler.lineart_detector(image)
    return depth, lineart
14
+
15
  def predict(self, input_image, user_prompt=""):
16
  # 1. Pre-process Inputs
17
  print("Processing Input...")
18
  processed_image = resize_image_to_1mp(input_image)
19
 
20
  # 2. Get Face Embedding (Robust Mode)
 
21
  face_emb = self.mh.get_face_embedding(processed_image)
22
 
23
  # 3. Generate Prompt
 
30
  print(f"Prompt: {final_prompt}")
31
 
32
  # 4. Generate Control Maps (Structure)
33
+ print("Generating Control Maps (Depth, LineArt)...")
34
+ depth_map, lineart_map = self.prepare_control_images(processed_image)
35
 
36
  # 5. Logic for Face vs No-Face
37
  if face_emb is not None:
38
  print("Face detected: Applying InstantID.")
39
+ # [InstantID, Zoe, LineArt] (Must match load order in model.py)
40
+
41
+ # SCALE: InstantID Medium (0.6), Zoe Low (0.4), LineArt Low (0.4)
42
  controlnet_conditioning_scale = [0.6, 0.4, 0.4]
43
+
44
+ # STOP: InstantID stops EARLY (50%) to allow pixelation
45
  control_guidance_end = [0.5, 0.8, 0.8]
46
+
47
+ # IP Adapter Scale (Likeness): Keep High
48
  ip_adapter_scale = 0.9
49
 
50
+ # We must pass the face embedding and the image for the IP-Adapter
51
+ ip_adapter_image = processed_image
52
+ prompt_embeds, _ = self.mh.pipeline.ip_adapter.get_prompt_embeds(ip_adapter_image, face_emb, None)
53
+
 
54
  else:
55
  print("No face detected: Disabling InstantID, using only Structure+Style.")
56
  # Disable InstantID (Weight 0.0)
57
  controlnet_conditioning_scale = [0.0, 0.4, 0.4]
58
  control_guidance_end = [0.0, 0.8, 0.8]
59
  ip_adapter_scale = 0.0
60
+ prompt_embeds = None # No face embedding
 
 
 
 
61
 
62
  # 6. Run Inference
63
+ print("Running pipeline...")
64
  result = self.mh.pipeline(
65
  prompt=final_prompt,
66
+ prompt_embeds=prompt_embeds,
 
 
 
 
 
67
 
68
+ # ControlNet inputs
69
+ image=[processed_image, depth_map, lineart_map], # List for [ID, Zoe, LineArt]
70
  controlnet_conditioning_scale=controlnet_conditioning_scale,
71
  control_guidance_end=control_guidance_end,
72
+
73
+ # LCM settings
74
+ num_inference_steps=8,
75
+ guidance_scale=1.5,
76
+
77
  ).images[0]
78
 
79
  return result