primerz commited on
Commit
6977800
·
verified ·
1 Parent(s): fa327ca

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +56 -21
generator.py CHANGED
@@ -1,29 +1,67 @@
1
  import torch
2
  from config import Config
3
- from utils import resize_image_to_1mp, get_caption, draw_kps
4
  from PIL import Image
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def prepare_control_images(self, image, width, height):
11
  """
12
  Generates conditioning maps, ensuring they are resized
13
  to the exact target dimensions (width, height).
14
  """
15
  print(f"Generating control maps for {width}x{height}...")
16
-
17
- # Generate depth map
18
  depth_map_raw = self.mh.leres_detector(image)
19
-
20
- # Generate lineart map
21
  lineart_map_raw = self.mh.lineart_anime_detector(image)
22
-
23
- # Manually resize maps to match the exact output resolution
24
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
25
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
26
-
27
  return depth_map, lineart_map
28
 
29
  def predict(
@@ -31,16 +69,18 @@ class Generator:
31
  input_image,
32
  user_prompt="",
33
  negative_prompt="",
34
- guidance_scale=0.0, # TCD Default 0.0
35
- num_inference_steps=6,
36
- img2img_strength=0.3,
 
 
37
  depth_strength=0.3,
38
  lineart_strength=0.3,
39
  seed=-1
40
  ):
41
- # 1. Pre-process Inputs
42
  print("Processing Input...")
43
- processed_image = resize_image_to_1mp(input_image)
44
  target_width, target_height = processed_image.size
45
 
46
  # 2. Get Face Info
@@ -53,7 +93,7 @@ class Generator:
53
  final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
54
  except Exception as e:
55
  print(f"Captioning failed: {e}, using default prompt.")
56
- final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
57
  else:
58
  final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
59
 
@@ -67,23 +107,18 @@ class Generator:
67
  # 5. Logic for Face vs No-Face
68
  if face_info is not None:
69
  print("Face detected: Applying InstantID with keypoints.")
70
-
71
- # Use Raw Embedding
72
  face_emb = torch.tensor(
73
  face_info['embedding'],
74
  dtype=Config.DTYPE,
75
  device=Config.DEVICE
76
  ).unsqueeze(0)
77
-
78
  face_kps = draw_kps(processed_image, face_info['kps'])
79
-
80
  controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
81
  self.mh.pipeline.set_ip_adapter_scale(0.8)
82
  else:
83
  print("No face detected: Disabling InstantID.")
84
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
85
  face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
86
-
87
  controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
88
  self.mh.pipeline.set_ip_adapter_scale(0.0)
89
 
@@ -105,7 +140,7 @@ class Generator:
105
  generator=generator,
106
 
107
  strength=img2img_strength,
108
- guidance_scale=guidance_scale,
109
  num_inference_steps=num_inference_steps,
110
 
111
  controlnet_conditioning_scale=controlnet_conditioning_scale,
@@ -113,7 +148,7 @@ class Generator:
113
  clip_skip=2,
114
 
115
  # --- TCD Specific Parameter ---
116
- eta=0.3,
117
  # ------------------------------
118
 
119
  ).images[0]
 
1
  import torch
2
  from config import Config
3
+ from utils import get_caption, draw_kps # Removed resize_image_to_1mp
4
  from PIL import Image
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
+ def smart_crop_and_resize(self, image):
11
+ """
12
+ Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
13
+ Performs a center crop to match the target ratio, then resizes.
14
+ """
15
+ w, h = image.size
16
+ aspect_ratio = w / h
17
+
18
+ # 1. Determine Target Resolution (Horizon SDXL Buckets)
19
+ if 0.85 <= aspect_ratio <= 1.15:
20
+ target_w, target_h = 1024, 1024
21
+ print(f"Snap to Bucket: Square (1024x1024)")
22
+ elif aspect_ratio < 0.85:
23
+ if aspect_ratio < 0.72:
24
+ target_w, target_h = 832, 1216 # Tall Portrait
25
+ print(f"Snap to Bucket: Tall Portrait (832x1216)")
26
+ else:
27
+ target_w, target_h = 896, 1152 # Standard Portrait
28
+ print(f"Snap to Bucket: Portrait (896x1152)")
29
+ else: # aspect_ratio > 1.15
30
+ if aspect_ratio > 1.35:
31
+ target_w, target_h = 1216, 832 # Wide Landscape
32
+ print(f"Snap to Bucket: Wide Landscape (1216x832)")
33
+ else:
34
+ target_w, target_h = 1152, 896 # Standard Landscape
35
+ print(f"Snap to Bucket: Landscape (1152x896)")
36
+
37
+ # 2. Center Crop to Target Aspect Ratio
38
+ target_ar = target_w / target_h
39
+
40
+ if aspect_ratio > target_ar:
41
+ new_w = int(h * target_ar)
42
+ offset = (w - new_w) // 2
43
+ crop_box = (offset, 0, offset + new_w, h)
44
+ else:
45
+ new_h = int(w / target_ar)
46
+ offset = (h - new_h) // 2
47
+ crop_box = (0, offset, w, offset + new_h)
48
+
49
+ cropped_img = image.crop(crop_box)
50
+
51
+ # 3. Resize to Exact Target Resolution
52
+ final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
53
+ return final_img
54
+
55
  def prepare_control_images(self, image, width, height):
56
  """
57
  Generates conditioning maps, ensuring they are resized
58
  to the exact target dimensions (width, height).
59
  """
60
  print(f"Generating control maps for {width}x{height}...")
 
 
61
  depth_map_raw = self.mh.leres_detector(image)
 
 
62
  lineart_map_raw = self.mh.lineart_anime_detector(image)
 
 
63
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
64
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
 
65
  return depth_map, lineart_map
66
 
67
  def predict(
 
69
  input_image,
70
  user_prompt="",
71
  negative_prompt="",
72
+ # --- TCD Optimized Defaults ---
73
+ guidance_scale=0.0,
74
+ num_inference_steps=8, # TCD works well at 8 steps
75
+ img2img_strength=0.9, # Needs to be high for img2img
76
+ # ----------------------------
77
  depth_strength=0.3,
78
  lineart_strength=0.3,
79
  seed=-1
80
  ):
81
+ # 1. Pre-process Inputs (Using Smart Crop)
82
  print("Processing Input...")
83
+ processed_image = self.smart_crop_and_resize(input_image)
84
  target_width, target_height = processed_image.size
85
 
86
  # 2. Get Face Info
 
93
  final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
94
  except Exception as e:
95
  print(f"Captioning failed: {e}, using default prompt.")
96
+ final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
97
  else:
98
  final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
99
 
 
107
  # 5. Logic for Face vs No-Face
108
  if face_info is not None:
109
  print("Face detected: Applying InstantID with keypoints.")
 
 
110
  face_emb = torch.tensor(
111
  face_info['embedding'],
112
  dtype=Config.DTYPE,
113
  device=Config.DEVICE
114
  ).unsqueeze(0)
 
115
  face_kps = draw_kps(processed_image, face_info['kps'])
 
116
  controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
117
  self.mh.pipeline.set_ip_adapter_scale(0.8)
118
  else:
119
  print("No face detected: Disabling InstantID.")
120
  face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
121
  face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
 
122
  controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
123
  self.mh.pipeline.set_ip_adapter_scale(0.0)
124
 
 
140
  generator=generator,
141
 
142
  strength=img2img_strength,
143
+ guidance_scale=guidance_scale, # Will be 0.0 from default
144
  num_inference_steps=num_inference_steps,
145
 
146
  controlnet_conditioning_scale=controlnet_conditioning_scale,
 
148
  clip_skip=2,
149
 
150
  # --- TCD Specific Parameter ---
151
+ eta=0.3, # Gamma/Stochasticity
152
  # ------------------------------
153
 
154
  ).images[0]