primerz committed on
Commit
ff014fd
·
verified ·
1 Parent(s): 2799929

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +39 -139
generator.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from config import Config
3
- from utils import get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
@@ -15,23 +15,23 @@ class Generator:
15
  w, h = image.size
16
  aspect_ratio = w / h
17
 
18
- # 1. Determine Target Resolution (SDXL Buckets)
19
  if 0.85 <= aspect_ratio <= 1.15:
20
  target_w, target_h = 1024, 1024
21
  print(f"Snap to Bucket: Square (1024x1024)")
22
  elif aspect_ratio < 0.85:
23
  if aspect_ratio < 0.72:
24
- target_w, target_h = 832, 1216 # Tall Portrait
25
  print(f"Snap to Bucket: Tall Portrait (832x1216)")
26
  else:
27
- target_w, target_h = 896, 1152 # Standard Portrait
28
  print(f"Snap to Bucket: Portrait (896x1152)")
29
- else: # aspect_ratio > 1.15
30
  if aspect_ratio > 1.35:
31
- target_w, target_h = 1216, 832 # Wide Landscape
32
  print(f"Snap to Bucket: Wide Landscape (1216x832)")
33
  else:
34
- target_w, target_h = 1152, 896 # Standard Landscape
35
  print(f"Snap to Bucket: Landscape (1152x896)")
36
 
37
  # 2. Center Crop to Target Aspect Ratio
@@ -52,93 +52,33 @@ class Generator:
52
  final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
53
  return final_img
54
 
55
- def prepare_control_images(
56
- self,
57
- image,
58
- width,
59
- height,
60
- edge_type=None,
61
- canny_low=100,
62
- canny_high=200
63
- ):
64
  """
65
- Generates conditioning maps based on edge_type.
66
-
67
- Returns:
68
- tuple: (depth_map, edge_maps_list) where edge_maps_list matches the ControlNet setup
69
  """
70
- if edge_type is None:
71
- edge_type = self.mh.edge_type
72
-
73
- print(f"Generating control maps ({edge_type}) for {width}x{height}...")
74
-
75
- # Always generate depth
76
- depth_map_raw = self.mh.extract_depth(image)
77
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
78
-
79
- edge_maps = []
80
-
81
- if edge_type == "canny":
82
- canny_map_raw = self.mh.extract_canny(image, canny_low, canny_high)
83
- canny_map = canny_map_raw.resize((width, height), Image.LANCZOS)
84
- edge_maps.append(canny_map)
85
- print(f" ✓ Canny edges generated")
86
-
87
- elif edge_type == "lineart":
88
- lineart_map_raw = self.mh.extract_lineart(image)
89
- lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
90
- edge_maps.append(lineart_map)
91
- print(f" ✓ LineArt edges generated")
92
-
93
- elif edge_type == "both":
94
- canny_map_raw = self.mh.extract_canny(image, canny_low, canny_high)
95
- canny_map = canny_map_raw.resize((width, height), Image.LANCZOS)
96
- edge_maps.append(canny_map)
97
-
98
- lineart_map_raw = self.mh.extract_lineart(image)
99
- lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
100
- edge_maps.append(lineart_map)
101
- print(f" ✓ Both Canny and LineArt generated")
102
-
103
- return depth_map, edge_maps
104
 
105
  def predict(
106
  self,
107
  input_image,
108
  user_prompt="",
109
  negative_prompt="",
110
- guidance_scale=4.0,
111
- num_inference_steps=8,
112
- img2img_strength=0.9,
113
- depth_strength=0.3,
114
- edge_strength=0.3,
115
- instantid_strength=0.8,
116
- canny_low_threshold=100,
117
- canny_high_threshold=200,
118
- eta=0.45,
119
- seed=-1,
120
- return_control_images=False
121
  ):
122
- """
123
- Enhanced prediction with more control options.
124
-
125
- Args:
126
- input_image: PIL Image
127
- user_prompt: Text prompt (optional, will auto-caption if empty)
128
- negative_prompt: Negative prompt
129
- guidance_scale: CFG scale (4.0 recommended for TCD + LoRA)
130
- num_inference_steps: Number of steps (4-12 for TCD)
131
- img2img_strength: Denoising strength
132
- depth_strength: Depth ControlNet strength
133
- edge_strength: Edge ControlNet strength (canny/lineart)
134
- instantid_strength: Face preservation strength
135
- canny_low_threshold: Canny low threshold (if using canny)
136
- canny_high_threshold: Canny high threshold (if using canny)
137
- eta: TCD stochasticity parameter
138
- seed: Random seed (-1 for random)
139
- return_control_images: Return control images for debugging
140
- """
141
- # 1. Pre-process Inputs
142
  print("Processing Input...")
143
  processed_image = self.smart_crop_and_resize(input_image)
144
  target_width, target_height = processed_image.size
@@ -161,92 +101,52 @@ class Generator:
161
  print(f"Negative Prompt: {negative_prompt}")
162
 
163
  # 4. Generate Control Maps
164
- print("Generating Control Maps...")
165
- depth_map, edge_maps = self.prepare_control_images(
166
- processed_image,
167
- target_width,
168
- target_height,
169
- canny_low=canny_low_threshold,
170
- canny_high=canny_high_threshold
171
- )
172
-
173
- # 5. Setup conditioning based on face detection
174
- control_images = []
175
- conditioning_scales = []
176
- control_guidance_end = []
177
 
 
178
  if face_info is not None:
179
- print(f"Face detected: Applying InstantID (strength: {instantid_strength})")
180
  face_emb = torch.tensor(
181
  face_info['embedding'],
182
  dtype=Config.DTYPE,
183
  device=Config.DEVICE
184
  ).unsqueeze(0)
185
  face_kps = draw_kps(processed_image, face_info['kps'])
186
-
187
- # Add face keypoints
188
- control_images.append(face_kps)
189
- conditioning_scales.append(instantid_strength)
190
- control_guidance_end.append(0.3)
191
-
192
- # Set IP-Adapter scale for face
193
- self.mh.pipeline.set_ip_adapter_scale(instantid_strength)
194
  else:
195
- print("No face detected: Disabling InstantID")
196
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
197
  face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
198
-
199
- # Add placeholder face keypoints
200
- control_images.append(face_kps)
201
- conditioning_scales.append(0.0)
202
- control_guidance_end.append(0.6)
203
-
204
  self.mh.pipeline.set_ip_adapter_scale(0.0)
205
-
206
- # Add depth map
207
- control_images.append(depth_map)
208
- conditioning_scales.append(depth_strength)
209
- control_guidance_end.append(0.6)
210
-
211
- # Add edge map(s)
212
- for edge_map in edge_maps:
213
- control_images.append(edge_map)
214
- conditioning_scales.append(edge_strength)
215
- control_guidance_end.append(0.6)
216
 
217
- # 6. Setup seed
 
218
  if seed == -1 or seed is None:
219
  seed = torch.Generator().seed()
220
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
221
  print(f"Using seed: {seed}")
222
 
223
- # 7. Run Inference
224
- print(f"Running pipeline (steps: {num_inference_steps}, cfg: {guidance_scale}, eta: {eta})...")
225
  result = self.mh.pipeline(
226
  prompt=final_prompt,
227
  negative_prompt=negative_prompt,
228
  image=processed_image,
229
- control_image=control_images,
230
  image_embeds=face_emb,
231
  generator=generator,
232
 
233
  strength=img2img_strength,
234
  guidance_scale=guidance_scale,
235
- num_inference_steps=num_inference_steps,
236
 
237
- controlnet_conditioning_scale=conditioning_scales,
238
  control_guidance_end=control_guidance_end,
239
  clip_skip=0,
240
 
241
- eta=eta,
242
  ).images[0]
243
 
244
- if return_control_images:
245
- return result, {
246
- 'depth': depth_map,
247
- 'edges': edge_maps,
248
- 'face_kps': face_kps if face_info else None,
249
- 'processed_input': processed_image
250
- }
251
-
252
  return result
 
1
  import torch
2
  from config import Config
3
+ from utils import get_caption, draw_kps # Removed resize_image_to_1mp
4
  from PIL import Image
5
 
6
  class Generator:
 
15
  w, h = image.size
16
  aspect_ratio = w / h
17
 
18
+ # 1. Determine Target Resolution (Horizon SDXL Buckets)
19
  if 0.85 <= aspect_ratio <= 1.15:
20
  target_w, target_h = 1024, 1024
21
  print(f"Snap to Bucket: Square (1024x1024)")
22
  elif aspect_ratio < 0.85:
23
  if aspect_ratio < 0.72:
24
+ target_w, target_h = 832, 1216 # Tall Portrait
25
  print(f"Snap to Bucket: Tall Portrait (832x1216)")
26
  else:
27
+ target_w, target_h = 896, 1152 # Standard Portrait
28
  print(f"Snap to Bucket: Portrait (896x1152)")
29
+ else: # aspect_ratio > 1.15
30
  if aspect_ratio > 1.35:
31
+ target_w, target_h = 1216, 832 # Wide Landscape
32
  print(f"Snap to Bucket: Wide Landscape (1216x832)")
33
  else:
34
+ target_w, target_h = 1152, 896 # Standard Landscape
35
  print(f"Snap to Bucket: Landscape (1152x896)")
36
 
37
  # 2. Center Crop to Target Aspect Ratio
 
52
  final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
53
  return final_img
54
 
55
+ def prepare_control_images(self, image, width, height):
 
 
 
 
 
 
 
 
56
  """
57
+ Generates conditioning maps, ensuring they are resized
58
+ to the exact target dimensions (width, height).
 
 
59
  """
60
+ print(f"Generating control maps for {width}x{height}...")
61
+ depth_map_raw = self.mh.leres_detector(image)
62
+ lineart_map_raw = self.mh.lineart_anime_detector(image)
 
 
 
 
63
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
64
+ lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
65
+ return depth_map, lineart_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def predict(
68
  self,
69
  input_image,
70
  user_prompt="",
71
  negative_prompt="",
72
+ # --- DPMSolver++ Optimized Defaults ---
73
+ guidance_scale=7.0,
74
+ num_inference_steps=20,
75
+ img2img_strength=0.85,
76
+ # ----------------------------
77
+ depth_strength=0.8,
78
+ lineart_strength=0.8,
79
+ seed=-1
 
 
 
80
  ):
81
+ # 1. Pre-process Inputs (Using Smart Crop)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  print("Processing Input...")
83
  processed_image = self.smart_crop_and_resize(input_image)
84
  target_width, target_height = processed_image.size
 
101
  print(f"Negative Prompt: {negative_prompt}")
102
 
103
  # 4. Generate Control Maps
104
+ print("Generating Control Maps (Depth, LineArt)...")
105
+ depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ # 5. Logic for Face vs No-Face
108
  if face_info is not None:
109
+ print("Face detected: Applying InstantID with keypoints.")
110
  face_emb = torch.tensor(
111
  face_info['embedding'],
112
  dtype=Config.DTYPE,
113
  device=Config.DEVICE
114
  ).unsqueeze(0)
115
  face_kps = draw_kps(processed_image, face_info['kps'])
116
+ controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
117
+ self.mh.pipeline.set_ip_adapter_scale(0.8)
 
 
 
 
 
 
118
  else:
119
+ print("No face detected: Disabling InstantID.")
120
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
121
  face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
122
+ controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
 
 
 
 
 
123
  self.mh.pipeline.set_ip_adapter_scale(0.0)
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ control_guidance_end = [0.3, 0.6, 0.6]
126
+
127
  if seed == -1 or seed is None:
128
  seed = torch.Generator().seed()
129
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
130
  print(f"Using seed: {seed}")
131
 
132
+ # 6. Run Inference
133
+ print("Running pipeline...")
134
  result = self.mh.pipeline(
135
  prompt=final_prompt,
136
  negative_prompt=negative_prompt,
137
  image=processed_image,
138
+ control_image=[face_kps, depth_map, lineart_map],
139
  image_embeds=face_emb,
140
  generator=generator,
141
 
142
  strength=img2img_strength,
143
  guidance_scale=guidance_scale,
144
+ num_inference_steps=num_inference_steps,
145
 
146
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
147
  control_guidance_end=control_guidance_end,
148
  clip_skip=0,
149
 
 
150
  ).images[0]
151
 
 
 
 
 
 
 
 
 
152
  return result