File size: 5,261 Bytes
911dcd6
 
3885620
a910636
911dcd6
 
 
 
 
ff014fd
460592a
ff014fd
 
460592a
ff014fd
3885620
 
ff014fd
3885620
 
ff014fd
3885620
 
60bf1c5
ff014fd
3885620
ff014fd
3e3e641
31c79b1
 
 
 
f389872
3885620
 
 
27381b4
3885620
 
ff014fd
31c79b1
3885620
911dcd6
3885620
60bf1c5
 
3885620
589234e
911dcd6
 
 
5a9aef6
c82ccd6
 
5a9aef6
 
3885620
911dcd6
 
 
 
cb173bd
911dcd6
3885620
ff014fd
 
911dcd6
3885620
 
62e516c
589234e
ff014fd
3885620
f3238f2
 
 
963056d
f3238f2
 
963056d
 
f3238f2
3885620
 
589234e
3885620
27381b4
 
911dcd6
ff014fd
3885620
a910636
3885620
589234e
0df2aa6
27381b4
3885620
 
ff014fd
589234e
 
3885620
036809c
ff014fd
3885620
069fe14
 
 
 
3885620
069fe14
ff014fd
 
911dcd6
 
f389872
3885620
ff014fd
3885620
f389872
60bf1c5
3885620
5cf276c
ff014fd
3885620
 
31c79b1
ff014fd
911dcd6
3885620
7dd1e3b
228348f
911dcd6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
from PIL import Image

class Generator:
    """Img2img pixel-art generator driving a multi-ControlNet diffusion pipeline.

    ControlNet order used throughout is [InstantID_KPS, Depth, LineArt];
    conditioning scales and guidance-end lists must follow that order.
    """

    def __init__(self, model_handler):
        """Store the model handler that owns the detectors, face analysis
        and the diffusers pipeline used by :meth:`predict`."""
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).

        Returns:
            tuple[Image, Image]: (depth_map, lineart_map), each exactly
            ``width`` x ``height``.
        """
        print(f"Generating control maps for {width}x{height}...")

        # Generate depth map
        depth_map_raw = self.mh.leres_detector(image)

        # Generate lineart map
        lineart_map_raw = self.mh.lineart_anime_detector(image)

        # Detectors may return maps at their own working resolution, so
        # resize both to match the exact output resolution.
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)

        return depth_map, lineart_map

    def predict(
        self, 
        input_image, 
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        face_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        """Run the full img2img generation for one input image.

        Args:
            input_image: Source PIL image; resized to ~1 megapixel first.
            user_prompt: Optional prompt; when blank a caption is generated.
            negative_prompt: Passed straight to the pipeline.
            guidance_scale: Classifier-free guidance scale.
            num_inference_steps: Diffusion step count.
            img2img_strength: Denoising strength for the base image.
            face_strength: InstantID KPS ControlNet scale (0 if no face).
            depth_strength: Depth ControlNet scale.
            lineart_strength: LineArt ControlNet scale.
            seed: RNG seed; -1 or None picks a random seed.

        Returns:
            The first generated PIL image.
        """
        # 1. Pre-process Inputs
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size
        
        # 2. Get Face Info (replaces get_face_embedding)
        face_info = self.mh.get_face_info(processed_image)
        
        # 3. Generate Prompt — fall back to auto-captioning when the user
        # supplied nothing, and to a static prompt if captioning fails.
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
            
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate OTHER Control Maps (Structure)
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
        
        # 5. Logic for Face vs No-Face (NOW INCLUDES KPS)
        # ControlNet order: [InstantID_KPS, Zoe_Depth, LineArt]
        
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            
            # We use face_info['embedding'] (raw) instead of normed_embedding.
            # Raw embedding has higher magnitude (~20-30) required for the adapter.
            face_emb = torch.tensor(
                face_info['embedding'], 
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)

            # Create keypoint image for the InstantID KPS ControlNet.
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [face_strength, depth_strength, lineart_strength] 
            # Fixed IP-Adapter scale for identity injection.
            self.mh.pipeline.set_ip_adapter_scale(0.75)
        else:
            print("No face detected: Disabling InstantID.")
            # Dummy zero embedding — the pipeline still expects the input,
            # but the 0.0 scales below neutralize its effect.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            # Dummy keypoint image (black)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            
            # Zero out the face ControlNet and the IP-Adapter entirely.
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength] 
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # We keep the guidance_end for pose low so structural control
        # releases before the final refinement steps.
        control_guidance_end = [face_strength * 0.4, depth_strength * 0.8, lineart_strength * 0.6] 

        # --- Seed/Generator Logic ---
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")
        # --- END ---

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,  # Base img2img image
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,  # Face identity embedding
            generator=generator,
            
            # --- Parameters from UI ---
            strength=img2img_strength,
            num_inference_steps=num_inference_steps, 
            guidance_scale=guidance_scale,
            # --- End Parameters from UI ---
            
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            
            clip_skip=1,
            
        ).images[0]
        
        return result