import torch
from config import Config
from utils import get_caption, draw_kps
from PIL import Image

class Generator:
    def __init__(self, model_handler):
        self.mh = model_handler

    def smart_crop_and_resize(self, image):
        """
        Analyzes aspect ratio and snaps to the best SDXL resolution bucket.
        Performs a center crop to match the target ratio, then resizes.
        """
        w, h = image.size
        aspect_ratio = w / h
        
        # 1. Determine Target Resolution (Horizon SDXL Buckets)
        if 0.85 <= aspect_ratio <= 1.15:
            target_w, target_h = 1024, 1024
            print("Snap to Bucket: Square (1024x1024)")
        elif aspect_ratio < 0.85:
            if aspect_ratio < 0.72:
                target_w, target_h = 832, 1216  # Tall Portrait
                print("Snap to Bucket: Tall Portrait (832x1216)")
            else:
                target_w, target_h = 896, 1152  # Standard Portrait
                print("Snap to Bucket: Portrait (896x1152)")
        else:  # aspect_ratio > 1.15
            if aspect_ratio > 1.35:
                target_w, target_h = 1216, 832  # Wide Landscape
                print("Snap to Bucket: Wide Landscape (1216x832)")
            else:
                target_w, target_h = 1152, 896  # Standard Landscape
                print("Snap to Bucket: Landscape (1152x896)")
                
        # 2. Center Crop to Target Aspect Ratio
        target_ar = target_w / target_h
        
        if aspect_ratio > target_ar:
            new_w = int(h * target_ar)
            offset = (w - new_w) // 2
            crop_box = (offset, 0, offset + new_w, h)
        else:
            new_h = int(w / target_ar)
            offset = (h - new_h) // 2
            crop_box = (0, offset, w, offset + new_h)
            
        cropped_img = image.crop(crop_box)
        
        # 3. Resize to Exact Target Resolution
        final_img = cropped_img.resize((target_w, target_h), Image.LANCZOS)
        return final_img
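
    # Worked example: a 4000x3000 input (AR ~1.33) lands in the 1.15-1.35 band
    # and snaps to 1152x896; target_ar = 1152/896 ~= 1.286, so the center crop
    # trims the width to int(3000 * 1.286) = 3857 px before the final resize.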

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image) 
        lineart_map_raw = self.mh.lineart_anime_detector(image)
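        # The annotators (presumably controlnet_aux detectors) may return maps
        # at their own working resolution, so force-match the bucket size to
        # keep every ControlNet input aligned with the generation resolution.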
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def predict(
        self, 
        input_image, 
        user_prompt="",
        negative_prompt="",
        # --- TCD-optimized defaults ---
        guidance_scale=4.0,  # non-zero so CFG and the negative prompt take effect
        num_inference_steps=8,
        img2img_strength=0.9,
        # -------------------------------
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        # 1. Pre-process Inputs (Using Smart Crop)
        print("Processing Input...")
        processed_image = self.smart_crop_and_resize(input_image)
        target_width, target_height = processed_image.size
        
        # 2. Get Face Info
        face_info = self.mh.get_face_info(processed_image)
        
        # 3. Generate Prompt
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
            
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate Control Maps
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
        
        # 5. Logic for Face vs No-Face
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            face_emb = torch.tensor(
                face_info['embedding'], 
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
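            # Scales are positional and must match the control_image order
            # passed to the pipeline below: [face keypoints, depth, lineart].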
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength] 
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength] 
            self.mh.pipeline.set_ip_adapter_scale(0.0)

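        # control_guidance_end sets the fraction of denoising steps after which
        # each ControlNet stops being applied: keypoints at 30%, depth and
        # lineart at 60% of the schedule.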
        control_guidance_end = [0.3, 0.6, 0.6] 

        if seed is None or seed == -1:
            seed = torch.Generator().seed()  # draw a fresh random seed (logged below for reproducibility)
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            
            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps, 
            
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=0,
            
            # --- TCD-specific parameter ---
            eta=0.45,  # TCD "gamma": controls per-step stochasticity
            # -------------------------------
            
        ).images[0]
        
        return result
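
# Minimal usage sketch (not part of the original module). `ModelHandler` and its
# module path are assumptions inferred from the attributes used above
# (.pipeline, .get_face_info, .leres_detector, .lineart_anime_detector); adjust
# the import to match the real project layout.
if __name__ == "__main__":
    from model_handler import ModelHandler  # hypothetical module path

    gen = Generator(ModelHandler())
    src = Image.open("input.jpg").convert("RGB")
    out = gen.predict(src, user_prompt="portrait, soft window light", seed=1234)
    out.save("output.png")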