primerz committed on
Commit
5cf276c
·
verified ·
1 Parent(s): f5bcb07

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +21 -101
generator.py CHANGED
@@ -1,85 +1,29 @@
1
  import torch
2
  from config import Config
3
- from utils import get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
- def solve_bezier(self, t, p0, p1, p2, p3):
11
- """
12
- Calculates a point on a cubic Bezier curve for a given t (0 to 1).
13
- """
14
- t = max(0.0, min(1.0, t))
15
- term0 = (1 - t)**3 * p0
16
- term1 = 3 * (1 - t)**2 * t * p1
17
- term2 = 3 * (1 - t) * t**2 * p2
18
- term3 = t**3 * p3
19
- return term0 + term1 + term2 + term3
20
-
21
- def smart_crop_and_resize(self, image):
22
  """
23
- Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
24
- Performs a center crop to match the target ratio, then resizes.
25
  """
26
- w, h = image.size
27
- aspect_ratio = w / h
28
-
29
- # 1. Determine Target Resolution (Horizon SDXL Buckets)
30
- if 0.85 <= aspect_ratio <= 1.15:
31
- # Square-ish -> 1024x1024
32
- target_w, target_h = 1024, 1024
33
- print(f"Snap to Bucket: Square (1024x1024)")
34
-
35
- elif aspect_ratio < 0.85:
36
- # Portrait
37
- # Decide between 896x1152 (AR ~0.77) and 832x1216 (AR ~0.68)
38
- if aspect_ratio < 0.72:
39
- target_w, target_h = 832, 1216 # Tall Portrait
40
- print(f"Snap to Bucket: Tall Portrait (832x1216)")
41
- else:
42
- target_w, target_h = 896, 1152 # Standard Portrait
43
- print(f"Snap to Bucket: Portrait (896x1152)")
44
-
45
- else: # aspect_ratio > 1.15
46
- # Landscape
47
- # Decide between 1152x896 (AR ~1.28) and 1216x832 (AR ~1.46)
48
- if aspect_ratio > 1.35:
49
- target_w, target_h = 1216, 832 # Wide Landscape
50
- print(f"Snap to Bucket: Wide Landscape (1216x832)")
51
- else:
52
- target_w, target_h = 1152, 896 # Standard Landscape
53
- print(f"Snap to Bucket: Landscape (1152x896)")
54
-
55
- # 2. Center Crop to Target Aspect Ratio
56
- target_ar = target_w / target_h
57
-
58
- if aspect_ratio > target_ar:
59
- # Image is wider than target -> Crop width (cut sides)
60
- new_w = int(h * target_ar)
61
- offset = (w - new_w) // 2
62
- crop_box = (offset, 0, offset + new_w, h)
63
- else:
64
- # Image is taller than target -> Crop height (cut top/bottom)
65
- new_h = int(w / target_ar)
66
- offset = (h - new_h) // 2
67
- crop_box = (0, offset, w, offset + new_h)
68
-
69
- cropped_img = image.crop(crop_box)
70
-
71
- # 3. Resize to Exact Target Resolution
72
- final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
73
- return final_img
74
-
75
- def prepare_control_images(self, image, width, height):
76
  print(f"Generating control maps for {width}x{height}...")
 
 
77
  depth_map_raw = self.mh.leres_detector(image)
 
 
78
  lineart_map_raw = self.mh.lineart_anime_detector(image)
79
 
80
- # Maps are resized to match the exact bucket resolution
81
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
82
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
 
83
  return depth_map, lineart_map
84
 
85
  def predict(
@@ -94,39 +38,15 @@ class Generator:
94
  lineart_strength=0.3,
95
  seed=-1
96
  ):
97
- # 1. Pre-process Inputs (New Smart Crop)
98
  print("Processing Input...")
99
- processed_image = self.smart_crop_and_resize(input_image)
 
100
  target_width, target_height = processed_image.size
101
 
102
  # 2. Get Face Info
103
  face_info = self.mh.get_face_info(processed_image)
104
 
105
- # --- CUBIC BEZIER ADAPTIVE LOGIC ---
106
- adaptive_cfg = guidance_scale
107
- adaptive_strength = img2img_strength
108
-
109
- if face_info is not None:
110
- # 1. Calculate Face Coverage (t)
111
- bbox = face_info['bbox']
112
- face_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
113
- total_area = target_width * target_height
114
- coverage_ratio = face_area / total_area
115
-
116
- print(f"Face Coverage: {coverage_ratio:.3f} ({int(coverage_ratio * 12)}/12)")
117
-
118
- # 2. Define Control Points (Half Less Aggressive)
119
- cfg_mult = self.solve_bezier(coverage_ratio, 0.825, 0.85, 0.95, 1.0)
120
- str_mult = self.solve_bezier(coverage_ratio, 0.9375, 0.95, 0.99, 1.0)
121
-
122
- # 3. Apply Multipliers
123
- adaptive_cfg = guidance_scale * cfg_mult
124
- adaptive_strength = img2img_strength * str_mult
125
-
126
- print(f"-> CFG Multiplier: {cfg_mult:.3f} | New CFG: {adaptive_cfg:.2f}")
127
- print(f"-> Str Multiplier: {str_mult:.3f} | New Strength: {adaptive_strength:.2f}")
128
- # --- END ADAPTIVE LOGIC ---
129
-
130
  # 3. Generate Prompt
131
  if not user_prompt.strip():
132
  try:
@@ -141,10 +61,10 @@ class Generator:
141
  print(f"Prompt: {final_prompt}")
142
 
143
  # 4. Generate Control Maps
144
- print("Generating Control Maps...")
145
  depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
146
 
147
- # 5. Face vs No-Face Setup
148
  if face_info is not None:
149
  print("Face detected: Applying InstantID with keypoints.")
150
 
@@ -184,18 +104,18 @@ class Generator:
184
  image_embeds=face_emb,
185
  generator=generator,
186
 
187
- # --- Using Adaptive Values ---
188
- strength=adaptive_strength,
189
- guidance_scale=adaptive_cfg,
190
- num_inference_steps=num_inference_steps,
191
- # ---------------------------
192
 
193
  controlnet_conditioning_scale=controlnet_conditioning_scale,
194
  control_guidance_end=control_guidance_end,
195
  clip_skip=2,
196
 
197
  # --- TCD Specific Parameter ---
198
- eta=0.3, # Controls stochasticity (gamma) for TCD
199
  # ------------------------------
200
 
201
  ).images[0]
 
1
  import torch
2
  from config import Config
3
+ from utils import resize_image_to_1mp, get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
+ def prepare_control_images(self, image, width, height):
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
+ Generates conditioning maps, ensuring they are resized
13
+ to the exact target dimensions (width, height).
14
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  print(f"Generating control maps for {width}x{height}...")
16
+
17
+ # Generate depth map
18
  depth_map_raw = self.mh.leres_detector(image)
19
+
20
+ # Generate lineart map
21
  lineart_map_raw = self.mh.lineart_anime_detector(image)
22
 
23
+ # Manually resize maps to match the exact output resolution
24
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
25
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
26
+
27
  return depth_map, lineart_map
28
 
29
  def predict(
 
38
  lineart_strength=0.3,
39
  seed=-1
40
  ):
41
+ # 1. Pre-process Inputs
42
  print("Processing Input...")
43
+ # Reverted to original aspect-ratio preserving resize
44
+ processed_image = resize_image_to_1mp(input_image)
45
  target_width, target_height = processed_image.size
46
 
47
  # 2. Get Face Info
48
  face_info = self.mh.get_face_info(processed_image)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # 3. Generate Prompt
51
  if not user_prompt.strip():
52
  try:
 
61
  print(f"Prompt: {final_prompt}")
62
 
63
  # 4. Generate Control Maps
64
+ print("Generating Control Maps (Depth, LineArt)...")
65
  depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
66
 
67
+ # 5. Logic for Face vs No-Face
68
  if face_info is not None:
69
  print("Face detected: Applying InstantID with keypoints.")
70
 
 
104
  image_embeds=face_emb,
105
  generator=generator,
106
 
107
+ # --- Static Values (Adaptive Logic Removed) ---
108
+ strength=img2img_strength,
109
+ guidance_scale=guidance_scale,
110
+ num_inference_steps=num_inference_steps,
111
+ # --------------------------------------------
112
 
113
  controlnet_conditioning_scale=controlnet_conditioning_scale,
114
  control_guidance_end=control_guidance_end,
115
  clip_skip=2,
116
 
117
  # --- TCD Specific Parameter ---
118
+ eta=0.3,
119
  # ------------------------------
120
 
121
  ).images[0]