Spaces: Running on Zero
File size: 5,261 Bytes
import torch
from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
from PIL import Image
class Generator:
    def __init__(self, model_handler):
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")

        # Generate depth map
        depth_map_raw = self.mh.leres_detector(image)

        # Generate lineart map
        lineart_map_raw = self.mh.lineart_anime_detector(image)

        # Manually resize maps to match the exact output resolution
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)

        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        face_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        # 1. Pre-process Inputs
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Get Face Info (replaces get_face_embedding)
        face_info = self.mh.get_face_info(processed_image)

        # 3. Generate Prompt
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"

        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate OTHER Control Maps (Structure)
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)

        # 5. Logic for Face vs No-Face (NOW INCLUDES KPS)
        # ControlNet order: [InstantID_KPS, Zoe_Depth, LineArt]
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")

            # --- FIX APPLIED HERE ---
            # We use face_info['embedding'] (raw) instead of normed_embedding.
            # Raw embedding has higher magnitude (~20-30) required for the adapter.
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            # --- END FIX ---

            # Create keypoint image
            face_kps = draw_kps(processed_image, face_info['kps'])
            # Set ControlNet strengths from the UI sliders
            controlnet_conditioning_scale = [face_strength, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.75)
        else:
            print("No face detected: Disabling InstantID.")

            # Create dummy embedding
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)

            # Create dummy keypoint image (black)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            face_kps_guidance_end = 0.001
            face_strength_end = 0.001

            # Set strengths
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # We keep the guidance_end for pose low
        control_guidance_end = [face_strength * 0.4, depth_strength * 0.8, lineart_strength * 0.6]

        # --- Seed/Generator Logic ---
        if seed == -1 or seed is None:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")
        # --- END ---

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,  # Base img2img image
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,  # Face identity embedding
            generator=generator,
            # --- Parameters from UI ---
            strength=img2img_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            # --- End Parameters from UI ---
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=1,
        ).images[0]

        return result
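
A minimal usage sketch of this class, assuming a hypothetical ModelHandler defined elsewhere in the Space (not in this file) that loads the face analyzer, the LeReS and lineart-anime detectors, and the ControlNet/InstantID pipeline used above; the module name, image path, and prompt below are illustrative only:

from PIL import Image
from model_handler import ModelHandler  # hypothetical companion module

mh = ModelHandler()          # assumed to load detectors and the pipeline
generator = Generator(mh)

source = Image.open("portrait.jpg").convert("RGB")  # any input photo
pixel_art = generator.predict(
    source,
    user_prompt="a knight standing in a forest",
    seed=42,
)
pixel_art.save("pixel_art.png")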