# Hugging Face Spaces header (page-scrape artifact): "Spaces: Running on Zero"
| import torch | |
| from config import Config | |
| from utils import get_caption, draw_kps # Removed resize_image_to_1mp | |
| from PIL import Image | |
class Generator:
    """Stylized img2img generator built on a preloaded SDXL pipeline.

    Combines aspect-ratio "bucket" preprocessing, ControlNet conditioning
    (depth + anime line art) and optional InstantID face conditioning,
    all accessed through a shared model handler.
    """

    def __init__(self, model_handler):
        """Store the model handler that owns the pipeline and detectors.

        ``model_handler`` is expected to expose ``pipeline``,
        ``leres_detector``, ``lineart_anime_detector`` and
        ``get_face_info`` (see ``predict``).
        """
        self.mh = model_handler

    def _select_bucket(self, aspect_ratio):
        """Return the (width, height) SDXL bucket nearest *aspect_ratio*.

        Thresholds partition the ratio axis into five named buckets;
        the printed labels mirror the chosen resolution.
        """
        if 0.85 <= aspect_ratio <= 1.15:
            print("Snap to Bucket: Square (1024x1024)")
            return 1024, 1024
        if aspect_ratio < 0.85:
            if aspect_ratio < 0.72:
                print("Snap to Bucket: Tall Portrait (832x1216)")
                return 832, 1216  # Tall Portrait
            print("Snap to Bucket: Portrait (896x1152)")
            return 896, 1152  # Standard Portrait
        # aspect_ratio > 1.15
        if aspect_ratio > 1.35:
            print("Snap to Bucket: Wide Landscape (1216x832)")
            return 1216, 832  # Wide Landscape
        print("Snap to Bucket: Landscape (1152x896)")
        return 1152, 896  # Standard Landscape

    def smart_crop_and_resize(self, image):
        """
        Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
        Performs a center crop to match the target ratio, then resizes.

        :param image: PIL image of arbitrary size.
        :return: PIL image at exactly the bucket resolution.
        """
        w, h = image.size
        aspect_ratio = w / h

        # 1. Determine target resolution (SDXL buckets).
        target_w, target_h = self._select_bucket(aspect_ratio)

        # 2. Center-crop to the target aspect ratio.
        target_ar = target_w / target_h
        if aspect_ratio > target_ar:
            # Source is wider than the bucket: trim equally left/right.
            new_w = int(h * target_ar)
            offset = (w - new_w) // 2
            crop_box = (offset, 0, offset + new_w, h)
        else:
            # Source is taller than the bucket: trim equally top/bottom.
            new_h = int(w / target_ar)
            offset = (h - new_h) // 2
            crop_box = (0, offset, w, offset + new_h)
        cropped_img = image.crop(crop_box)

        # 3. Resize to the exact target resolution.
        final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
        return final_img

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).

        :param image: preprocessed PIL image fed to both detectors.
        :param width: target map width in pixels.
        :param height: target map height in pixels.
        :return: (depth_map, lineart_map) PIL images at (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        # Detectors may emit maps at their own native resolution, so the
        # results are resized to match the pipeline's target exactly.
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        # --- DPMSolver++ Optimized Defaults ---
        guidance_scale=7.0,
        num_inference_steps=20,
        img2img_strength=0.85,
        # ----------------------------
        depth_strength=0.8,
        lineart_strength=0.8,
        seed=-1
    ):
        """Run the full img2img pipeline on *input_image*.

        Preprocesses the image to an SDXL bucket, builds a prompt
        (auto-captioning when *user_prompt* is empty), generates control
        maps, toggles InstantID based on face detection, and returns the
        first generated PIL image.

        :param seed: -1 or None picks a fresh random seed.
        """
        # 1. Pre-process inputs (smart crop to an SDXL bucket).
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size

        # 2. Get face info (None when no face is detected).
        face_info = self.mh.get_face_info(processed_image)

        # 3. Build the prompt. Guard against None as well as empty/blank
        # strings so callers passing user_prompt=None don't crash.
        if not user_prompt or not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                # Captioning is best-effort; fall back to a generic prompt.
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate control maps at the exact target resolution.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Face vs. no-face conditioning. The three conditioning scales
        # line up with control_image = [face_kps, depth_map, lineart_map].
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            # Zero embedding + black keypoint image + 0.0 scale effectively
            # disable the InstantID branch while keeping tensor shapes valid.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # Stop each ControlNet early (fractions of total steps) so the last
        # denoising steps are unconstrained.
        control_guidance_end = [0.3, 0.6, 0.6]

        # Seed handling: -1/None means "pick a fresh random seed".
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run inference and return the single generated image.
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,
        ).images[0]
        return result