primerz committed on
Commit
0f36a29
·
verified ·
1 Parent(s): 364071a

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +10 -124
generator.py CHANGED
@@ -1,156 +1,42 @@
1
  import torch
2
  from config import Config
3
- from utils import get_caption, draw_kps # Removed resize_image_to_1mp
4
- from PIL import Image
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
- def smart_crop_and_resize(self, image):
11
- """
12
- Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
13
- Performs a center crop to match the target ratio, then resizes.
14
- """
15
- w, h = image.size
16
- aspect_ratio = w / h
17
-
18
- # 1. Determine Target Resolution (Horizon SDXL Buckets)
19
- if 0.85 <= aspect_ratio <= 1.15:
20
- target_w, target_h = 1024, 1024
21
- print(f"Snap to Bucket: Square (1024x1024)")
22
- elif aspect_ratio < 0.85:
23
- if aspect_ratio < 0.72:
24
- target_w, target_h = 832, 1216 # Tall Portrait
25
- print(f"Snap to Bucket: Tall Portrait (832x1216)")
26
- else:
27
- target_w, target_h = 896, 1152 # Standard Portrait
28
- print(f"Snap to Bucket: Portrait (896x1152)")
29
- else: # aspect_ratio > 1.15
30
- if aspect_ratio > 1.35:
31
- target_w, target_h = 1216, 832 # Wide Landscape
32
- print(f"Snap to Bucket: Wide Landscape (1216x832)")
33
- else:
34
- target_w, target_h = 1152, 896 # Standard Landscape
35
- print(f"Snap to Bucket: Landscape (1152x896)")
36
-
37
- # 2. Center Crop to Target Aspect Ratio
38
- target_ar = target_w / target_h
39
-
40
- if aspect_ratio > target_ar:
41
- new_w = int(h * target_ar)
42
- offset = (w - new_w) // 2
43
- crop_box = (offset, 0, offset + new_w, h)
44
- else:
45
- new_h = int(w / target_ar)
46
- offset = (h - new_h) // 2
47
- crop_box = (0, offset, w, offset + new_h)
48
-
49
- cropped_img = image.crop(crop_box)
50
-
51
- # 3. Resize to Exact Target Resolution
52
- final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
53
- return final_img
54
-
55
- def prepare_control_images(self, image, width, height):
56
- """
57
- Generates conditioning maps, ensuring they are resized
58
- to the exact target dimensions (width, height).
59
- """
60
- print(f"Generating control maps for {width}x{height}...")
61
- depth_map_raw = self.mh.leres_detector(image)
62
- lineart_map_raw = self.mh.lineart_anime_detector(image)
63
- depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
64
- lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
65
- return depth_map, lineart_map
66
-
67
  def predict(
68
  self,
69
- input_image,
70
- user_prompt="",
71
  negative_prompt="",
72
- # --- TCD Optimized Defaults ---
73
- guidance_scale=4.0, # <-- FIX: Set to non-zero default
74
- num_inference_steps=8,
75
- img2img_strength=0.9,
76
- # ----------------------------
77
- depth_strength=0.3,
78
- lineart_strength=0.3,
79
  seed=-1
80
  ):
81
- # 1. Pre-process Inputs (Using Smart Crop)
82
- print("Processing Input...")
83
- processed_image = self.smart_crop_and_resize(input_image)
84
- target_width, target_height = processed_image.size
85
-
86
- # 2. Get Face Info
87
- face_info = self.mh.get_face_info(processed_image)
88
-
89
- # 3. Generate Prompt
90
  if not user_prompt.strip():
91
- try:
92
- generated_caption = get_caption(processed_image)
93
- final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
94
- except Exception as e:
95
- print(f"Captioning failed: {e}, using default prompt.")
96
- final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
97
  else:
98
  final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
99
 
100
  print(f"Prompt: {final_prompt}")
101
- print(f"Negative Prompt: {negative_prompt}")
102
-
103
- # 4. Generate Control Maps
104
- print("Generating Control Maps (Depth, LineArt)...")
105
- depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
106
 
107
- # 5. Logic for Face vs No-Face
108
- if face_info is not None:
109
- print("Face detected: Applying InstantID with keypoints.")
110
- face_emb = torch.tensor(
111
- face_info['embedding'],
112
- dtype=Config.DTYPE,
113
- device=Config.DEVICE
114
- ).unsqueeze(0)
115
- face_kps = draw_kps(processed_image, face_info['kps'])
116
- controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
117
- self.mh.pipeline.set_ip_adapter_scale(0.8)
118
- else:
119
- print("No face detected: Disabling InstantID.")
120
- face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
121
- face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
122
- controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
123
- self.mh.pipeline.set_ip_adapter_scale(0.0)
124
-
125
- control_guidance_end = [0.3, 0.6, 0.6]
126
-
127
  if seed == -1 or seed is None:
128
  seed = torch.Generator().seed()
129
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
130
  print(f"Using seed: {seed}")
131
 
132
- # 6. Run Inference
133
  print("Running pipeline...")
134
  result = self.mh.pipeline(
135
  prompt=final_prompt,
136
  negative_prompt=negative_prompt,
137
- image=processed_image,
138
- control_image=[face_kps, depth_map, lineart_map],
139
- image_embeds=face_emb,
140
  generator=generator,
141
-
142
- strength=img2img_strength,
143
- guidance_scale=guidance_scale, # <-- Will use non-zero value
144
  num_inference_steps=num_inference_steps,
145
-
146
- controlnet_conditioning_scale=controlnet_conditioning_scale,
147
- control_guidance_end=control_guidance_end,
148
- clip_skip=0,
149
-
150
- # --- TCD Specific Parameter ---
151
- eta=0.45, # Gamma/Stochasticity
152
- # ------------------------------
153
-
154
  ).images[0]
155
 
156
  return result
 
1
  import torch
2
  from config import Config
 
 
3
 
4
  class Generator:
5
  def __init__(self, model_handler):
6
  self.mh = model_handler
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def predict(
9
  self,
10
+ user_prompt,
 
11
  negative_prompt="",
12
+ guidance_scale=1.2,
13
+ num_inference_steps=8,
 
 
 
 
 
14
  seed=-1
15
  ):
16
+ # 1. Construct Prompt
 
 
 
 
 
 
 
 
17
  if not user_prompt.strip():
18
+ # Fallback if user provides empty prompt
19
+ final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful landscape, pixel art"
 
 
 
 
20
  else:
21
  final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
22
 
23
  print(f"Prompt: {final_prompt}")
 
 
 
 
 
24
 
25
+ # 2. Handle Seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  if seed == -1 or seed is None:
27
  seed = torch.Generator().seed()
28
  generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
29
  print(f"Using seed: {seed}")
30
 
31
+ # 3. Run Text-to-Image Inference
32
  print("Running pipeline...")
33
  result = self.mh.pipeline(
34
  prompt=final_prompt,
35
  negative_prompt=negative_prompt,
 
 
 
36
  generator=generator,
 
 
 
37
  num_inference_steps=num_inference_steps,
38
+ guidance_scale=guidance_scale,
39
+ clip_skip=2, # Optional, often helps with anime/pixel styles
 
 
 
 
 
 
 
40
  ).images[0]
41
 
42
  return result