File size: 6,205 Bytes
6026000 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import torch
from config import Config
from utils import get_caption, draw_kps # Removed resize_image_to_1mp
from PIL import Image
class Generator:
    """Orchestrates stylized img2img generation.

    Pipeline stages (see ``predict``): smart-crop the input to an SDXL
    resolution bucket, derive depth/lineart control maps, optionally apply
    InstantID when a face is detected, then run the diffusers pipeline.
    """

    def __init__(self, model_handler):
        # model_handler bundles everything the methods below use:
        # the diffusers pipeline plus the detector models
        # (leres depth, lineart-anime, face analysis).
        self.mh = model_handler
def smart_crop_and_resize(self, image):
"""
Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
Performs a center crop to match the target ratio, then resizes.
"""
w, h = image.size
aspect_ratio = w / h
# 1. Determine Target Resolution (Horizon SDXL Buckets)
if 0.85 <= aspect_ratio <= 1.15:
target_w, target_h = 1024, 1024
print(f"Snap to Bucket: Square (1024x1024)")
elif aspect_ratio < 0.85:
if aspect_ratio < 0.72:
target_w, target_h = 832, 1216 # Tall Portrait
print(f"Snap to Bucket: Tall Portrait (832x1216)")
else:
target_w, target_h = 896, 1152 # Standard Portrait
print(f"Snap to Bucket: Portrait (896x1152)")
else: # aspect_ratio > 1.15
if aspect_ratio > 1.35:
target_w, target_h = 1216, 832 # Wide Landscape
print(f"Snap to Bucket: Wide Landscape (1216x832)")
else:
target_w, target_h = 1152, 896 # Standard Landscape
print(f"Snap to Bucket: Landscape (1152x896)")
# 2. Center Crop to Target Aspect Ratio
target_ar = target_w / target_h
if aspect_ratio > target_ar:
new_w = int(h * target_ar)
offset = (w - new_w) // 2
crop_box = (offset, 0, offset + new_w, h)
else:
new_h = int(w / target_ar)
offset = (h - new_h) // 2
crop_box = (0, offset, w, offset + new_h)
cropped_img = image.crop(crop_box)
# 3. Resize to Exact Target Resolution
final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
return final_img
def prepare_control_images(self, image, width, height):
"""
Generates conditioning maps, ensuring they are resized
to the exact target dimensions (width, height).
"""
print(f"Generating control maps for {width}x{height}...")
depth_map_raw = self.mh.leres_detector(image)
lineart_map_raw = self.mh.lineart_anime_detector(image)
depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
return depth_map, lineart_map
    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        # --- TCD Optimized Defaults ---
        guidance_scale=4.0, # <-- FIX: Set to non-zero default
        num_inference_steps=8,
        img2img_strength=0.9,
        # ----------------------------
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        """Run the full stylized img2img pipeline on ``input_image``.

        Steps: smart-crop to an SDXL bucket, detect a face (enabling
        InstantID when present), build the final prompt (auto-caption when
        ``user_prompt`` is empty), generate depth/lineart control maps,
        then invoke the diffusers pipeline.

        Args:
            input_image: Source PIL image; any size (it is bucketed first).
            user_prompt: Optional prompt text; prefixed with
                ``Config.STYLE_TRIGGER``. Empty -> auto caption.
            negative_prompt: Negative prompt string passed through.
            guidance_scale: CFG scale; kept non-zero so the negative
                prompt takes effect.
            num_inference_steps: Denoising step count.
            img2img_strength: img2img denoise strength.
            depth_strength: ControlNet scale for the depth map.
            lineart_strength: ControlNet scale for the lineart map.
            seed: RNG seed; -1 or None draws a random seed.

        Returns:
            The first generated PIL image from the pipeline output.
        """
        # 1. Pre-process Inputs (Using Smart Crop)
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size
        # 2. Get Face Info (None when no face is detected)
        face_info = self.mh.get_face_info(processed_image)
        # 3. Generate Prompt: auto-caption when the user gave none.
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                # Best-effort: a captioning failure falls back to a generic prompt.
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        # 4. Generate Control Maps sized to the bucket resolution.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
        # 5. Face vs no-face: InstantID is driven by the face embedding,
        #    the keypoint control image, and the IP-Adapter scale.
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)  # add batch dim; assumes a 512-d embedding to match the zero fallback below — confirm
            face_kps = draw_kps(processed_image, face_info['kps'])
            # ControlNet scale order: [face-kps, depth, lineart]
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            # Zero embedding + black keypoint image + 0.0 scales neutralize InstantID.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)
        # Stop each ControlNet early so the final denoising steps run unconstrained.
        control_guidance_end = [0.3, 0.6, 0.6]
        if seed == -1 or seed is None:
            seed = torch.Generator().seed()  # draw a fresh random seed
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")
        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            strength=img2img_strength,
            guidance_scale=guidance_scale, # <-- Will use non-zero value
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,  # NOTE(review): diffusers pipelines usually expect None or >=1 here — confirm 0 is accepted
            # --- TCD Specific Parameter ---
            eta=0.45, # Gamma/Stochasticity for the TCD scheduler
            # ------------------------------
        ).images[0]
        return result