primerz committed on
Commit
589234e
·
verified ·
1 Parent(s): e4dd0ff

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +34 -31
generator.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from config import Config
3
- from utils import resize_image_to_1mp, get_caption
4
  from PIL import Image
5
 
6
  class Generator:
@@ -19,16 +19,12 @@ class Generator:
19
 
20
  # Generate lineart map
21
  lineart_map_raw = self.mh.lineart_anime_detector(image)
22
-
23
- # --- MODIFIED: Removed tile map ---
24
- # --- END MODIFIED ---
25
 
26
  # Manually resize maps to match the exact output resolution
27
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
28
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
29
- # tile_map = tile_map_raw.resize((width, height), Image.LANCZOS) # <-- REMOVED
30
 
31
- return depth_map, lineart_map # <-- MODIFIED
32
 
33
  def predict(
34
  self,
@@ -40,7 +36,6 @@ class Generator:
40
  img2img_strength=0.3,
41
  depth_strength=0.3,
42
  lineart_strength=0.3,
43
- # tile_strength=0.7, # <-- REMOVED
44
  seed=-1
45
  ):
46
  # 1. Pre-process Inputs
@@ -48,8 +43,8 @@ class Generator:
48
  processed_image = resize_image_to_1mp(input_image)
49
  target_width, target_height = processed_image.size
50
 
51
- # 2. Get Face Embedding (Robust Mode)
52
- face_emb = self.mh.get_face_embedding(processed_image)
53
 
54
  # 3. Generate Prompt
55
  if not user_prompt.strip():
@@ -65,43 +60,51 @@ class Generator:
65
  print(f"Prompt: {final_prompt}")
66
  print(f"Negative Prompt: {negative_prompt}")
67
 
68
- # 4. Generate Control Maps (Structure)
69
- print("Generating Control Maps (Depth, LineArt)...") # <-- MODIFIED
70
- depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height) # <-- MODIFIED
71
 
72
- # 5. Logic for Face vs No-Face
73
- # --- MODIFIED: Removed Tile Control ---
74
- # ControlNet order: [InstantID, Zoe, LineArt] # <-- MODIFIED
75
- if face_emb is not None:
76
- print("Face detected: Applying InstantID.")
77
- controlnet_conditioning_scale = [0.45, depth_strength, lineart_strength] # <-- MODIFIED
78
- control_guidance_end = [0.3, 0.6, 0.6] # <-- MODIFIED
79
- self.mh.pipeline.set_ip_adapter_scale(0.45)
 
 
 
 
80
  else:
81
  print("No face detected: Disabling InstantID.")
82
- controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength] # <-- MODIFIED
83
- control_guidance_end = [0.3, 0.6, 0.6] # <-- MODIFIED
84
- self.mh.pipeline.set_ip_adapter_scale(0.0)
85
-
86
- # --- START FIX for NoneType Error ---
87
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
88
- # --- END FIX ---
 
 
 
 
 
 
 
 
89
 
90
- # --- ADDED: Seed/Generator Logic ---
91
  if seed == -1 or seed is None:
92
  seed = torch.Generator().seed()
93
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
94
  print(f"Using seed: {seed}")
95
- # --- END ADDED ---
96
 
97
  # 6. Run Inference
98
  print("Running pipeline...")
99
  result = self.mh.pipeline(
100
  prompt=final_prompt,
101
  negative_prompt=negative_prompt,
102
- image=processed_image,
103
- control_image=[processed_image, depth_map, lineart_map], # <-- MODIFIED
104
- image_embeds=face_emb,
105
  generator=generator,
106
 
107
  # --- Parameters from UI ---
 
1
  import torch
2
  from config import Config
3
+ from utils import resize_image_to_1mp, get_caption, draw_kps # <-- MODIFIED
4
  from PIL import Image
5
 
6
  class Generator:
 
19
 
20
  # Generate lineart map
21
  lineart_map_raw = self.mh.lineart_anime_detector(image)
 
 
 
22
 
23
  # Manually resize maps to match the exact output resolution
24
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
25
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
 
26
 
27
+ return depth_map, lineart_map # <-- MODIFIED (kps is now handled in predict)
28
 
29
  def predict(
30
  self,
 
36
  img2img_strength=0.3,
37
  depth_strength=0.3,
38
  lineart_strength=0.3,
 
39
  seed=-1
40
  ):
41
  # 1. Pre-process Inputs
 
43
  processed_image = resize_image_to_1mp(input_image)
44
  target_width, target_height = processed_image.size
45
 
46
+ # 2. Get Face Info (replaces get_face_embedding)
47
+ face_info = self.mh.get_face_info(processed_image)
48
 
49
  # 3. Generate Prompt
50
  if not user_prompt.strip():
 
60
  print(f"Prompt: {final_prompt}")
61
  print(f"Negative Prompt: {negative_prompt}")
62
 
63
+ # 4. Generate OTHER Control Maps (Structure)
64
+ print("Generating Control Maps (Depth, LineArt)...")
65
+ depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
66
 
67
+ # 5. Logic for Face vs No-Face (NOW INCLUDES KPS)
68
+ # ControlNet order: [InstantID_KPS, Zoe_Depth, LineArt]
69
+ if face_info is not None:
70
+ print("Face detected: Applying InstantID with keypoints.")
71
+ # Get embedding
72
+ face_emb = torch.tensor(face_info.normed_embedding).unsqueeze(0)
73
+ # Create keypoint image
74
+ face_kps = draw_kps(processed_image, face_info['kps'])
75
+
76
+ # Set strengths (using 0.8 from file's example)
77
+ controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
78
+ self.mh.pipeline.set_ip_adapter_scale(0.8)
79
  else:
80
  print("No face detected: Disabling InstantID.")
81
+ # Create dummy embedding
 
 
 
 
82
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
83
+ # Create dummy keypoint image (black)
84
+ face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
85
+
86
+ # Set strengths
87
+ controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
88
+ self.mh.pipeline.set_ip_adapter_scale(0.0)
89
+
90
+ # We keep the guidance_end for pose low
91
+ control_guidance_end = [0.3, 0.6, 0.6]
92
 
93
+ # --- Seed/Generator Logic ---
94
  if seed == -1 or seed is None:
95
  seed = torch.Generator().seed()
96
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
97
  print(f"Using seed: {seed}")
98
+ # --- END ---
99
 
100
  # 6. Run Inference
101
  print("Running pipeline...")
102
  result = self.mh.pipeline(
103
  prompt=final_prompt,
104
  negative_prompt=negative_prompt,
105
+ image=processed_image, # Base img2img image
106
+ control_image=[face_kps, depth_map, lineart_map], # <-- MODIFIED
107
+ image_embeds=face_emb, # Face identity embedding
108
  generator=generator,
109
 
110
  # --- Parameters from UI ---