primerz committed on
Commit
460592a
·
verified ·
1 Parent(s): e63b057

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +66 -17
generator.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from config import Config
3
- from utils import resize_image_to_1mp, get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
@@ -10,21 +10,74 @@ class Generator:
10
  def solve_bezier(self, t, p0, p1, p2, p3):
11
  """
12
  Calculates a point on a cubic Bezier curve for a given t (0 to 1).
13
- Formula: B(t) = (1-t)^3*P0 + 3*(1-t)^2*t*P1 + 3*(1-t)*t^2*P2 + t^3*P3
14
  """
15
  t = max(0.0, min(1.0, t))
16
-
17
  term0 = (1 - t)**3 * p0
18
  term1 = 3 * (1 - t)**2 * t * p1
19
  term2 = 3 * (1 - t) * t**2 * p2
20
  term3 = t**3 * p3
21
-
22
  return term0 + term1 + term2 + term3
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def prepare_control_images(self, image, width, height):
25
  print(f"Generating control maps for {width}x{height}...")
26
  depth_map_raw = self.mh.leres_detector(image)
27
  lineart_map_raw = self.mh.lineart_anime_detector(image)
 
 
28
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
29
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
30
  return depth_map, lineart_map
@@ -41,9 +94,9 @@ class Generator:
41
  lineart_strength=0.3,
42
  seed=-1
43
  ):
44
- # 1. Pre-process Inputs
45
  print("Processing Input...")
46
- processed_image = resize_image_to_1mp(input_image)
47
  target_width, target_height = processed_image.size
48
 
49
  # 2. Get Face Info
@@ -62,15 +115,11 @@ class Generator:
62
 
63
  print(f"Face Coverage: {coverage_ratio:.3f} ({int(coverage_ratio * 12)}/12)")
64
 
65
- # 2. Define Control Points (LESS AGGRESSIVE REDUCTION)
66
-
67
- # CFG CURVE:
68
- # Old P0 was 0.65 (35% drop). New P0 is 0.825 (17.5% drop).
69
- # Curve eases from 0.825 up to 1.0 smoothly.
70
  cfg_mult = self.solve_bezier(coverage_ratio, 0.825, 0.85, 0.95, 1.0)
71
 
72
- # STRENGTH CURVE:
73
- # Old P0 was 0.875 (12.5% drop). New P0 is 0.9375 (~6% drop).
74
  str_mult = self.solve_bezier(coverage_ratio, 0.9375, 0.95, 0.99, 1.0)
75
 
76
  # 3. Apply Multipliers
@@ -102,7 +151,7 @@ class Generator:
102
  if face_info is not None:
103
  print("Face detected: Applying InstantID with keypoints.")
104
 
105
- # Use Raw Embedding (Fix)
106
  face_emb = torch.tensor(
107
  face_info['embedding'],
108
  dtype=Config.DTYPE,
@@ -111,8 +160,8 @@ class Generator:
111
 
112
  face_kps = draw_kps(processed_image, face_info['kps'])
113
 
114
- controlnet_conditioning_scale = [0.777, depth_strength, lineart_strength]
115
- self.mh.pipeline.set_ip_adapter_scale(0.777)
116
  else:
117
  print("No face detected: Disabling InstantID.")
118
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
@@ -121,7 +170,7 @@ class Generator:
121
  controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
122
  self.mh.pipeline.set_ip_adapter_scale(0.0)
123
 
124
- control_guidance_end = [0.333, 0.666, 0.666]
125
 
126
  if seed == -1 or seed is None:
127
  seed = torch.Generator().seed()
 
1
  import torch
2
  from config import Config
3
+ from utils import get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
 
10
  def solve_bezier(self, t, p0, p1, p2, p3):
11
  """
12
  Calculates a point on a cubic Bezier curve for a given t (0 to 1).
 
13
  """
14
  t = max(0.0, min(1.0, t))
 
15
  term0 = (1 - t)**3 * p0
16
  term1 = 3 * (1 - t)**2 * t * p1
17
  term2 = 3 * (1 - t) * t**2 * p2
18
  term3 = t**3 * p3
 
19
  return term0 + term1 + term2 + term3
20
 
21
def smart_crop_and_resize(self, image):
    """
    Snap an image to the closest SDXL resolution bucket.

    The bucket is picked from the aspect ratio (width / height):

        ratio < 0.72           -> 832x1216  (Tall Portrait)
        0.72 <= ratio < 0.85   -> 896x1152  (Portrait)
        0.85 <= ratio <= 1.15  -> 1024x1024 (Square)
        1.15 < ratio <= 1.35   -> 1152x896  (Landscape)
        ratio > 1.35           -> 1216x832  (Wide Landscape)

    The image is then center-cropped to the bucket's aspect ratio and
    resized with LANCZOS resampling to the exact bucket resolution.

    Args:
        image: Object exposing ``size`` as (width, height), ``crop(box)``
            and ``resize(size, resample)``, e.g. a PIL image.

    Returns:
        The cropped and resized image, sized to the selected bucket.

    Raises:
        ValueError: If the image width or height is not positive.
    """
    w, h = image.size
    if w <= 0 or h <= 0:
        # Fail with a clear message instead of a ZeroDivisionError or an
        # empty crop further down.
        raise ValueError(f"Cannot crop an empty image (size {w}x{h}).")

    aspect_ratio = w / h

    # 1. Determine Target Resolution (SDXL buckets)
    if aspect_ratio < 0.72:
        label, target_w, target_h = "Tall Portrait", 832, 1216
    elif aspect_ratio < 0.85:
        label, target_w, target_h = "Portrait", 896, 1152
    elif aspect_ratio <= 1.15:
        label, target_w, target_h = "Square", 1024, 1024
    elif aspect_ratio <= 1.35:
        label, target_w, target_h = "Landscape", 1152, 896
    else:
        label, target_w, target_h = "Wide Landscape", 1216, 832
    print(f"Snap to Bucket: {label} ({target_w}x{target_h})")

    # 2. Center Crop to Target Aspect Ratio
    # Integer cross-multiplication keeps the comparison and the crop size
    # exact; dividing by a float ratio and truncating can lose a pixel
    # when the input already matches the bucket ratio.
    if w * target_h > h * target_w:
        # Image is wider than target -> crop width (cut sides)
        new_w = h * target_w // target_h
        offset = (w - new_w) // 2
        crop_box = (offset, 0, offset + new_w, h)
    else:
        # Image is taller than (or equal to) target -> crop height
        new_h = w * target_h // target_w
        offset = (h - new_h) // 2
        crop_box = (0, offset, w, offset + new_h)

    # 3. Resize to Exact Target Resolution
    return image.crop(crop_box).resize((target_w, target_h), Image.LANCZOS)
74
+
75
  def prepare_control_images(self, image, width, height):
76
  print(f"Generating control maps for {width}x{height}...")
77
  depth_map_raw = self.mh.leres_detector(image)
78
  lineart_map_raw = self.mh.lineart_anime_detector(image)
79
+
80
+ # Maps are resized to match the exact bucket resolution
81
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
82
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
83
  return depth_map, lineart_map
 
94
  lineart_strength=0.3,
95
  seed=-1
96
  ):
97
+ # 1. Pre-process Inputs (New Smart Crop)
98
  print("Processing Input...")
99
+ processed_image = self.smart_crop_and_resize(input_image)
100
  target_width, target_height = processed_image.size
101
 
102
  # 2. Get Face Info
 
115
 
116
  print(f"Face Coverage: {coverage_ratio:.3f} ({int(coverage_ratio * 12)}/12)")
117
 
118
+ # 2. Define Control Points (Half Less Aggressive)
119
+ # CFG: 0.825 start (17.5% reduction)
 
 
 
120
  cfg_mult = self.solve_bezier(coverage_ratio, 0.825, 0.85, 0.95, 1.0)
121
 
122
+ # Strength: 0.9375 start (6.25% reduction)
 
123
  str_mult = self.solve_bezier(coverage_ratio, 0.9375, 0.95, 0.99, 1.0)
124
 
125
  # 3. Apply Multipliers
 
151
  if face_info is not None:
152
  print("Face detected: Applying InstantID with keypoints.")
153
 
154
+ # Use Raw Embedding
155
  face_emb = torch.tensor(
156
  face_info['embedding'],
157
  dtype=Config.DTYPE,
 
160
 
161
  face_kps = draw_kps(processed_image, face_info['kps'])
162
 
163
+ controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
164
+ self.mh.pipeline.set_ip_adapter_scale(0.8)
165
  else:
166
  print("No face detected: Disabling InstantID.")
167
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
 
170
  controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
171
  self.mh.pipeline.set_ip_adapter_scale(0.0)
172
 
173
+ control_guidance_end = [0.3, 0.6, 0.6]
174
 
175
  if seed == -1 or seed is None:
176
  seed = torch.Generator().seed()