|
|
import torch |
|
|
from config import Config |
|
|
from utils import get_caption, draw_kps |
|
|
from PIL import Image |
|
|
|
|
|
class Generator: |
|
|
    def __init__(self, model_handler):
        """Wrap the model handler that owns the heavy ML objects.

        Args:
            model_handler: object exposing ``pipeline``, ``leres_detector``,
                ``lineart_anime_detector`` and ``get_face_info`` (all used by
                the methods of this class).
        """
        # Short alias used throughout the class.
        self.mh = model_handler
|
|
|
|
|
def smart_crop_and_resize(self, image): |
|
|
""" |
|
|
Analyzes aspect ratio and snaps to the best SDXL resolution bucket. |
|
|
Performs a center crop to match the target ratio, then resizes. |
|
|
""" |
|
|
w, h = image.size |
|
|
aspect_ratio = w / h |
|
|
|
|
|
|
|
|
if 0.85 <= aspect_ratio <= 1.15: |
|
|
target_w, target_h = 1024, 1024 |
|
|
print(f"Snap to Bucket: Square (1024x1024)") |
|
|
elif aspect_ratio < 0.85: |
|
|
if aspect_ratio < 0.72: |
|
|
target_w, target_h = 832, 1216 |
|
|
print(f"Snap to Bucket: Tall Portrait (832x1216)") |
|
|
else: |
|
|
target_w, target_h = 896, 1152 |
|
|
print(f"Snap to Bucket: Portrait (896x1152)") |
|
|
else: |
|
|
if aspect_ratio > 1.35: |
|
|
target_w, target_h = 1216, 832 |
|
|
print(f"Snap to Bucket: Wide Landscape (1216x832)") |
|
|
else: |
|
|
target_w, target_h = 1152, 896 |
|
|
print(f"Snap to Bucket: Landscape (1152x896)") |
|
|
|
|
|
|
|
|
target_ar = target_w / target_h |
|
|
|
|
|
if aspect_ratio > target_ar: |
|
|
new_w = int(h * target_ar) |
|
|
offset = (w - new_w) // 2 |
|
|
crop_box = (offset, 0, offset + new_w, h) |
|
|
else: |
|
|
new_h = int(w / target_ar) |
|
|
offset = (h - new_h) // 2 |
|
|
crop_box = (0, offset, w, offset + new_h) |
|
|
|
|
|
cropped_img = image.crop(crop_box) |
|
|
|
|
|
|
|
|
final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS) |
|
|
return final_img |
|
|
|
|
|
def prepare_control_images(self, image, width, height): |
|
|
""" |
|
|
Generates conditioning maps, ensuring they are resized |
|
|
to the exact target dimensions (width, height). |
|
|
""" |
|
|
print(f"Generating control maps for {width}x{height}...") |
|
|
depth_map_raw = self.mh.leres_detector(image) |
|
|
lineart_map_raw = self.mh.lineart_anime_detector(image) |
|
|
depth_map = depth_map_raw.resize((width, height), Image.LANCZOS) |
|
|
lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS) |
|
|
return depth_map, lineart_map |
|
|
|
|
|
def predict( |
|
|
self, |
|
|
input_image, |
|
|
user_prompt="", |
|
|
negative_prompt="", |
|
|
|
|
|
guidance_scale=4.0, |
|
|
num_inference_steps=8, |
|
|
img2img_strength=0.9, |
|
|
|
|
|
depth_strength=0.3, |
|
|
lineart_strength=0.3, |
|
|
seed=-1 |
|
|
): |
|
|
|
|
|
print("Processing Input...") |
|
|
processed_image = self.smart_crop_and_resize(input_image) |
|
|
target_width, target_height = processed_image.size |
|
|
|
|
|
|
|
|
face_info = self.mh.get_face_info(processed_image) |
|
|
|
|
|
|
|
|
if not user_prompt.strip(): |
|
|
try: |
|
|
generated_caption = get_caption(processed_image) |
|
|
final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}" |
|
|
except Exception as e: |
|
|
print(f"Captioning failed: {e}, using default prompt.") |
|
|
final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image" |
|
|
else: |
|
|
final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}" |
|
|
|
|
|
print(f"Prompt: {final_prompt}") |
|
|
print(f"Negative Prompt: {negative_prompt}") |
|
|
|
|
|
|
|
|
print("Generating Control Maps (Depth, LineArt)...") |
|
|
depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height) |
|
|
|
|
|
|
|
|
if face_info is not None: |
|
|
print("Face detected: Applying InstantID with keypoints.") |
|
|
face_emb = torch.tensor( |
|
|
face_info['embedding'], |
|
|
dtype=Config.DTYPE, |
|
|
device=Config.DEVICE |
|
|
).unsqueeze(0) |
|
|
face_kps = draw_kps(processed_image, face_info['kps']) |
|
|
controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength] |
|
|
self.mh.pipeline.set_ip_adapter_scale(0.8) |
|
|
else: |
|
|
print("No face detected: Disabling InstantID.") |
|
|
face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE) |
|
|
face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0)) |
|
|
controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength] |
|
|
self.mh.pipeline.set_ip_adapter_scale(0.0) |
|
|
|
|
|
control_guidance_end = [0.3, 0.6, 0.6] |
|
|
|
|
|
if seed == -1 or seed is None: |
|
|
seed = torch.Generator().seed() |
|
|
generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed)) |
|
|
print(f"Using seed: {seed}") |
|
|
|
|
|
|
|
|
print("Running pipeline...") |
|
|
result = self.mh.pipeline( |
|
|
prompt=final_prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
image=processed_image, |
|
|
control_image=[face_kps, depth_map, lineart_map], |
|
|
image_embeds=face_emb, |
|
|
generator=generator, |
|
|
|
|
|
strength=img2img_strength, |
|
|
guidance_scale=guidance_scale, |
|
|
num_inference_steps=num_inference_steps, |
|
|
|
|
|
controlnet_conditioning_scale=controlnet_conditioning_scale, |
|
|
control_guidance_end=control_guidance_end, |
|
|
clip_skip=0, |
|
|
|
|
|
|
|
|
eta=0.45, |
|
|
|
|
|
|
|
|
).images[0] |
|
|
|
|
|
return result |