File size: 4,416 Bytes
911dcd6
 
5cf276c
a910636
911dcd6
 
 
 
 
5cf276c
460592a
5cf276c
 
460592a
091ba99
5cf276c
 
091ba99
5cf276c
 
9196a3a
460592a
5cf276c
60bf1c5
 
5cf276c
963056d
3e3e641
31c79b1
 
 
 
f389872
963056d
 
 
 
 
f389872
31c79b1
5cf276c
911dcd6
5cf276c
60bf1c5
 
8a3467f
589234e
911dcd6
 
 
5a9aef6
c82ccd6
 
5a9aef6
 
 
911dcd6
 
 
 
 
091ba99
5cf276c
589234e
911dcd6
5cf276c
589234e
 
963056d
460592a
963056d
d319e6f
 
963056d
 
 
589234e
 
460592a
 
911dcd6
ddd791a
a910636
589234e
 
 
 
 
460592a
911dcd6
069fe14
 
 
 
 
911dcd6
3e3e641
911dcd6
 
f389872
8a3467f
963056d
8a3467f
f389872
60bf1c5
5cf276c
 
 
31c79b1
911dcd6
 
34c0b1c
3e3e641
ab4a566
 
 
228348f
911dcd6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
from PIL import Image

class Generator:
    """Img2img generator driving a multi-ControlNet (face-kps, depth, lineart)
    pipeline with optional InstantID face conditioning.

    The model_handler given at construction is expected to expose:
    the diffusion pipeline (`pipeline`), the depth and lineart detectors
    (`leres_detector`, `lineart_anime_detector`), and face analysis
    (`get_face_info`).
    """

    # Conditioning weights, previously inline magic numbers.
    FACE_CONTROL_SCALE = 0.8       # face-keypoint ControlNet weight when a face is found
    FACE_IP_ADAPTER_SCALE = 0.8    # IP-Adapter (InstantID) weight when a face is found
    # Per-ControlNet guidance cutoffs, ordered [face_kps, depth, lineart].
    CONTROL_GUIDANCE_END = [0.3, 0.6, 0.6]

    def __init__(self, model_handler):
        # Keep a handle to the shared models; no models are loaded here.
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generate depth and lineart conditioning maps, resized to the exact
        target dimensions (width, height).

        Args:
            image: source PIL image.
            width, height: target output resolution in pixels.

        Returns:
            (depth_map, lineart_map): two PIL images of size (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")

        # Generate raw maps; detectors may emit their own working resolution.
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)

        # Resize explicitly so the maps match the pipeline's output size.
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)

        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        """
        Run the full img2img generation for one input image.

        Args:
            input_image: source PIL image (resized to ~1 MP internally).
            user_prompt: optional prompt text; empty/None triggers auto-captioning.
            negative_prompt: negative prompt passed through to the pipeline.
            guidance_scale: classifier-free guidance scale.
            num_inference_steps: diffusion step count.
            img2img_strength: denoising strength for the img2img pass.
            depth_strength: depth ControlNet conditioning weight.
            lineart_strength: lineart ControlNet conditioning weight.
            seed: RNG seed; -1 or None picks a random seed.

        Returns:
            The generated PIL image.
        """
        # 1. Pre-process input to ~1 megapixel; all control maps match this size.
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Face detection (None when no face is found).
        face_info = self.mh.get_face_info(processed_image)

        # 3. Build the prompt. Guard against None so callers may pass
        #    user_prompt=None as well as "" to request auto-captioning.
        if not (user_prompt or "").strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                # Captioning is best-effort; fall back to a generic prompt.
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"

        print(f"Prompt: {final_prompt}")

        # 4. Depth + lineart control maps, sized to the output resolution.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Face vs no-face branch: enable or neutralize InstantID inputs.
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")

            # Raw face embedding -> (1, 512) tensor on the pipeline device.
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)

            face_kps = draw_kps(processed_image, face_info['kps'])

            controlnet_conditioning_scale = [
                self.FACE_CONTROL_SCALE, depth_strength, lineart_strength
            ]
            self.mh.pipeline.set_ip_adapter_scale(self.FACE_IP_ADAPTER_SCALE)
        else:
            print("No face detected: Disabling InstantID.")
            # Zero embedding + black keypoint image keep tensor/image shapes
            # valid while the 0.0 scales disable their influence.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))

            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # Copy so per-call mutation by the pipeline can't leak into the class.
        control_guidance_end = list(self.CONTROL_GUIDANCE_END)

        # Seed handling: -1 / None means "random".
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run inference.
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,

            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,

            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=2,

            # eta=0.0 removes stochastic (DDIM) noise for reproducible output.
            eta=0.0,
        ).images[0]

        return result