# pixel-art/generator.py
import torch
from PIL import Image

from config import Config
from utils import get_caption, draw_kps  # Removed resize_image_to_1mp


class Generator:
    def __init__(self, model_handler):
        self.mh = model_handler
    def smart_crop_and_resize(self, image):
        """
        Analyzes the aspect ratio and snaps to the best SDXL resolution bucket.
        Performs a center crop to match the target ratio, then resizes.
        """
        w, h = image.size
        aspect_ratio = w / h

        # 1. Determine Target Resolution (Horizon SDXL Buckets)
        if 0.85 <= aspect_ratio <= 1.15:
            target_w, target_h = 1024, 1024
            print("Snap to Bucket: Square (1024x1024)")
        elif aspect_ratio < 0.85:
            if aspect_ratio < 0.72:
                target_w, target_h = 832, 1216  # Tall Portrait
                print("Snap to Bucket: Tall Portrait (832x1216)")
            else:
                target_w, target_h = 896, 1152  # Standard Portrait
                print("Snap to Bucket: Portrait (896x1152)")
        else:  # aspect_ratio > 1.15
            if aspect_ratio > 1.35:
                target_w, target_h = 1216, 832  # Wide Landscape
                print("Snap to Bucket: Wide Landscape (1216x832)")
            else:
                target_w, target_h = 1152, 896  # Standard Landscape
                print("Snap to Bucket: Landscape (1152x896)")

        # 2. Center Crop to Target Aspect Ratio
        target_ar = target_w / target_h
        if aspect_ratio > target_ar:
            # Image is wider than the target: trim the sides.
            new_w = int(h * target_ar)
            offset = (w - new_w) // 2
            crop_box = (offset, 0, offset + new_w, h)
        else:
            # Image is taller than the target: trim top and bottom.
            new_h = int(w / target_ar)
            offset = (h - new_h) // 2
            crop_box = (0, offset, w, offset + new_h)
        cropped_img = image.crop(crop_box)

        # 3. Resize to Exact Target Resolution
        final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
        return final_img
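    # Worked example for smart_crop_and_resize: a 1920x1080 input has aspect
    # ratio ~1.78 (> 1.35), so it snaps to the 1216x832 bucket (ratio ~1.46).
    # The center crop keeps int(1080 * 1216 / 832) = 1578 of the 1920 columns,
    # i.e. crop box (171, 0, 1749, 1080), and the resulting 1578x1080 crop is
    # then LANCZOS-resized to exactly 1216x832.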
    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map
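    # Note: leres_detector and lineart_anime_detector are expected to be
    # callable preprocessors on the model handler (presumably controlnet_aux's
    # LeresDetector and LineartAnimeDetector) that return PIL images; their
    # native output size may differ from the bucket, hence the explicit resize.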
    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        # --- TCD Optimized Defaults ---
        guidance_scale=4.0,  # <-- FIX: Set to non-zero default
        num_inference_steps=8,
        img2img_strength=0.9,
        # ------------------------------
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1,
    ):
        # 1. Pre-process Inputs (Using Smart Crop)
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size

        # 2. Get Face Info
        face_info = self.mh.get_face_info(processed_image)

        # 3. Generate Prompt
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate Control Maps
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Logic for Face vs No-Face
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE,
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            # Scales are ordered to match control_image: [face_kps, depth, lineart].
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            # A zero embedding and a black keypoint image keep the pipeline's
            # input shapes intact while the 0.0 scales neutralize InstantID.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)
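        # In diffusers' multi-ControlNet pipelines, control_guidance_end sets
        # the fraction of total steps after which each ControlNet stops
        # applying, ordered like control_image: the InstantID keypoint net is
        # dropped at 30% of the schedule, depth and lineart at 60%, so the
        # final steps refine detail without structural constraints.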
        control_guidance_end = [0.3, 0.6, 0.6]

        # Seed handling: -1 or None requests a fresh random seed.
        if seed == -1 or seed is None:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            strength=img2img_strength,
            guidance_scale=guidance_scale,  # <-- Will use non-zero value
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,
            # --- TCD Specific Parameter ---
            eta=0.45,  # TCD "gamma" stochasticity (0.0 = deterministic)
            # ------------------------------
        ).images[0]
        return result
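

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: a ModelHandler class exposing
    # .pipeline, .get_face_info(), .leres_detector, and .lineart_anime_detector
    # lives in a model_handler module; it is not defined in this file, and the
    # module/class names and input path are illustrative guesses.
    from model_handler import ModelHandler

    gen = Generator(ModelHandler())
    out = gen.predict(Image.open("input.png").convert("RGB"), user_prompt="a castle")
    out.save("output.png")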