File size: 6,004 Bytes
911dcd6
 
ff014fd
a910636
911dcd6
 
 
 
 
6977800
 
 
 
 
 
 
 
ff014fd
6977800
 
 
 
 
ff014fd
6977800
 
ff014fd
6977800
ff014fd
6977800
ff014fd
6977800
 
ff014fd
6977800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff014fd
460592a
ff014fd
 
460592a
ff014fd
 
 
60bf1c5
ff014fd
 
3e3e641
31c79b1
 
 
 
f389872
ff014fd
 
 
 
 
 
 
 
31c79b1
ff014fd
911dcd6
6977800
60bf1c5
 
970f731
589234e
911dcd6
 
 
5a9aef6
c82ccd6
 
5a9aef6
 
6977800
911dcd6
 
 
 
cb173bd
911dcd6
970f731
ff014fd
 
911dcd6
ff014fd
589234e
ff014fd
963056d
d319e6f
 
963056d
 
589234e
ff014fd
 
911dcd6
ff014fd
a910636
589234e
ff014fd
589234e
 
ff014fd
 
069fe14
 
 
 
 
ff014fd
 
911dcd6
 
f389872
970f731
ff014fd
970f731
f389872
60bf1c5
5cf276c
5c3da03
ff014fd
31c79b1
ff014fd
911dcd6
d65e7f8
228348f
911dcd6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
from config import Config
from utils import get_caption, draw_kps # Removed resize_image_to_1mp
from PIL import Image

class Generator:
    """Orchestrates SDXL img2img generation with InstantID + multi-ControlNet.

    The heavy models are owned by ``model_handler`` (stored as ``self.mh``);
    this class only handles pre-processing (bucket crop/resize), prompt
    construction, control-map preparation, and the final pipeline call.
    """

    def __init__(self, model_handler):
        # model_handler is expected to expose: pipeline, leres_detector,
        # lineart_anime_detector, get_face_info. (Project type — see the
        # model handler module; not constructible here.)
        self.mh = model_handler

    def smart_crop_and_resize(self, image):
        """Snap *image* to the best-matching SDXL resolution bucket.

        Analyzes the aspect ratio, picks a standard SDXL bucket, performs a
        center crop to match the target ratio (no distortion), then resizes
        to the exact bucket resolution.

        Args:
            image: PIL.Image.Image input of arbitrary size.

        Returns:
            PIL.Image.Image of exactly (target_w, target_h).
        """
        w, h = image.size
        aspect_ratio = w / h

        # 1. Determine target resolution (SDXL bucket boundaries).
        #    NOTE: thresholds (0.72 / 0.85 / 1.15 / 1.35) match the
        #    original branch boundaries exactly, including inclusivity.
        if 0.85 <= aspect_ratio <= 1.15:
            target_w, target_h, label = 1024, 1024, "Square"
        elif aspect_ratio < 0.85:
            if aspect_ratio < 0.72:
                target_w, target_h, label = 832, 1216, "Tall Portrait"
            else:
                target_w, target_h, label = 896, 1152, "Portrait"
        else:  # aspect_ratio > 1.15
            if aspect_ratio > 1.35:
                target_w, target_h, label = 1216, 832, "Wide Landscape"
            else:
                target_w, target_h, label = 1152, 896, "Landscape"
        # Single formatted print replaces five placeholder-free f-strings;
        # the emitted text is identical to the original messages.
        print(f"Snap to Bucket: {label} ({target_w}x{target_h})")

        # 2. Center-crop to the target aspect ratio.
        target_ar = target_w / target_h
        if aspect_ratio > target_ar:
            # Input is wider than the bucket: trim equal amounts left/right.
            new_w = int(h * target_ar)
            offset = (w - new_w) // 2
            crop_box = (offset, 0, offset + new_w, h)
        else:
            # Input is taller (or equal): trim equal amounts top/bottom.
            new_h = int(w / target_ar)
            offset = (h - new_h) // 2
            crop_box = (0, offset, w, offset + new_h)
        cropped_img = image.crop(crop_box)

        # 3. Resize to the exact bucket resolution (Lanczos for quality).
        return cropped_img.resize((target_w, target_h), Image.LANCZOS)

    def prepare_control_images(self, image, width, height):
        """Generate depth and line-art conditioning maps at (width, height).

        The detectors may return maps at their own working resolution, so
        both are explicitly resized to the exact target dimensions.

        Args:
            image: pre-processed PIL.Image.Image.
            width: target map width in pixels.
            height: target map height in pixels.

        Returns:
            Tuple of (depth_map, lineart_map), both PIL images of
            exactly (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        # --- DPMSolver++ optimized defaults ---
        guidance_scale=7.0,
        num_inference_steps=20,
        img2img_strength=0.85,
        # ----------------------------
        depth_strength=0.8,
        lineart_strength=0.8,
        seed=-1,
        # InstantID conditioning / IP-adapter weight when a face is found.
        # New keyword with the previously hard-coded 0.8 as default, so
        # existing callers are unaffected.
        instantid_strength=0.8,
    ):
        """Run the full img2img pipeline on *input_image*.

        Args:
            input_image: PIL.Image.Image source image.
            user_prompt: optional prompt; when blank, a caption is generated.
            negative_prompt: negative prompt string.
            guidance_scale: classifier-free guidance scale.
            num_inference_steps: number of denoising steps.
            img2img_strength: denoise strength for img2img.
            depth_strength: depth ControlNet conditioning scale.
            lineart_strength: line-art ControlNet conditioning scale.
            seed: RNG seed; -1 or None selects a random seed.
            instantid_strength: InstantID ControlNet + IP-adapter scale
                applied when a face is detected (ignored otherwise).

        Returns:
            PIL.Image.Image — the first generated image.
        """
        # 1. Pre-process input (smart bucket crop + resize).
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size

        # 2. Face info (None when no face is detected).
        face_info = self.mh.get_face_info(processed_image)

        # 3. Build the prompt. Auto-caption when the user gave none;
        #    captioning is best-effort, so failures fall back gracefully.
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"

        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Conditioning maps at the exact working resolution.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Face vs. no-face branch. Ordering of the conditioning-scale
        #    list is [InstantID kps, depth, lineart] and must match the
        #    control_image list passed to the pipeline below.
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE,
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [
                instantid_strength, depth_strength, lineart_strength
            ]
            self.mh.pipeline.set_ip_adapter_scale(instantid_strength)
        else:
            print("No face detected: Disabling InstantID.")
            # Zero embedding + black keypoint canvas keep the pipeline's
            # input shapes valid while the 0.0 scales disable InstantID.
            face_emb = torch.zeros(
                (1, 512), dtype=Config.DTYPE, device=Config.DEVICE
            )
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # ControlNets release their grip early (fractions of total steps)
        # so the final steps refine without heavy conditioning.
        control_guidance_end = [0.3, 0.6, 0.6]

        # Check None before the == comparison (None == -1 happened to be
        # falsy, but the explicit order is clearer and intentional).
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run inference.
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,

            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,

            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,

        ).images[0]

        return result