Spaces: Running on Zero
File size: 6,004 Bytes

import torch
from config import Config
from utils import get_caption, draw_kps # Removed resize_image_to_1mp
from PIL import Image

class Generator:
    def __init__(self, model_handler):
        self.mh = model_handler
    def smart_crop_and_resize(self, image):
        """
        Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
        Performs a center crop to match the target ratio, then resizes.
        """
        w, h = image.size
        aspect_ratio = w / h

        # 1. Determine Target Resolution (Horizon SDXL Buckets)
        if 0.85 <= aspect_ratio <= 1.15:
            target_w, target_h = 1024, 1024
            print("Snap to Bucket: Square (1024x1024)")
        elif aspect_ratio < 0.85:
            if aspect_ratio < 0.72:
                target_w, target_h = 832, 1216  # Tall Portrait
                print("Snap to Bucket: Tall Portrait (832x1216)")
            else:
                target_w, target_h = 896, 1152  # Standard Portrait
                print("Snap to Bucket: Portrait (896x1152)")
        else:  # aspect_ratio > 1.15
            if aspect_ratio > 1.35:
                target_w, target_h = 1216, 832  # Wide Landscape
                print("Snap to Bucket: Wide Landscape (1216x832)")
            else:
                target_w, target_h = 1152, 896  # Standard Landscape
                print("Snap to Bucket: Landscape (1152x896)")

        # 2. Center Crop to Target Aspect Ratio
        target_ar = target_w / target_h
        if aspect_ratio > target_ar:
            new_w = int(h * target_ar)
            offset = (w - new_w) // 2
            crop_box = (offset, 0, offset + new_w, h)
        else:
            new_h = int(w / target_ar)
            offset = (h - new_h) // 2
            crop_box = (0, offset, w, offset + new_h)
        cropped_img = image.crop(crop_box)

        # 3. Resize to Exact Target Resolution
        final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
        return final_img
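
    # Worked example (added commentary, not in the original file): for a hypothetical
    # 3000x2000 input, aspect_ratio = 1.50 snaps to the 1216x832 landscape bucket
    # (target_ar ~ 1.4615); the center crop keeps new_w = int(2000 * 1216 / 832) = 2923
    # columns (offset = (3000 - 2923) // 2 = 38), and the resulting 2923x2000 crop is
    # resized to exactly 1216x832.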

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map
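
    # Note (added, hedged): the raw detector outputs may not match the bucketed
    # resolution, so both maps are resized here to keep all three control images
    # (keypoints, depth, lineart) at the exact img2img target size; the bucket
    # dimensions above are all multiples of 64 and so map cleanly onto SDXL's
    # 1/8-scale latent grid.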

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        # --- DPMSolver++ Optimized Defaults ---
        guidance_scale=7.0,
        num_inference_steps=20,
        img2img_strength=0.85,
        # ----------------------------
        depth_strength=0.8,
        lineart_strength=0.8,
        seed=-1
    ):
        # 1. Pre-process Inputs (Using Smart Crop)
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size

        # 2. Get Face Info
        face_info = self.mh.get_face_info(processed_image)

        # 3. Generate Prompt
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate Control Maps
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)

        # 5. Logic for Face vs No-Face
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        control_guidance_end = [0.3, 0.6, 0.6]
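
        # Note (added, hedged): assuming the diffusers multi-ControlNet convention, these
        # per-ControlNet lists follow the order of control_image below
        # [InstantID keypoints, depth, lineart]; control_guidance_end=[0.3, 0.6, 0.6]
        # stops the keypoint ControlNet after 30% of the denoising steps and the
        # depth/lineart ControlNets after 60%, leaving the remaining steps unconstrained
        # by ControlNet guidance.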

        if seed == -1 or seed is None:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,
        ).images[0]
        return result
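

# --- Usage sketch (added, hypothetical; not part of the original file) ---
# Assumes a ModelHandler class exists elsewhere in this Space (the import path and
# no-argument constructor below are guesses) and that it exposes .pipeline,
# .get_face_info(), .leres_detector and .lineart_anime_detector as used by Generator.
if __name__ == "__main__":
    from model_handler import ModelHandler  # hypothetical module name

    mh = ModelHandler()  # assumed no-arg constructor
    gen = Generator(mh)
    src = Image.open("input.jpg").convert("RGB")
    out = gen.predict(src, user_prompt="portrait of a person, soft light", seed=42)
    out.save("output.png")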