File size: 15,380 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
import io
import os
import threading
import queue
import numpy as np
import logging
from PIL import Image
from PIL.PngImagePlugin import PngInfo

# Logger named after this module.
logger = logging.getLogger(__name__)

# Directory returned by `get_output_directory()`; `SaveImage` writes beneath it.
output_directory = "./output"

# Cap on how many images one `save_images` call is allowed to write.
# A larger count usually means tiled intermediate outputs were passed in, and
# writing each tile as its own file would fill the disk, so such calls are
# rejected. Override at process start with the `LD_MAX_IMAGES_PER_SAVE`
# environment variable (default: 16).
MAX_IMAGES_PER_SAVE = int(os.environ.get("LD_MAX_IMAGES_PER_SAVE", "16"))

# In-memory store of encoded images for API responses, so the server does not
# have to read the files back from disk.
# Key: request filename prefix -> list of (filename, subfolder, png_bytes).
# All access goes through `_image_bytes_lock`.
_image_bytes_lock = threading.Lock()
_image_bytes_buffer: dict[str, list[tuple[str, str, bytes]]] = {}


def store_image_bytes(prefix: str, filename: str, subfolder: str, data: bytes) -> None:
    """#### Buffer one encoded image in memory for the API server.

    #### Args:
        - `prefix` (str): Key the entry is grouped under.
        - `filename` (str): File name of the image.
        - `subfolder` (str): Subfolder the image belongs to.
        - `data` (bytes): Encoded image bytes.
    """
    entry = (filename, subfolder, data)
    with _image_bytes_lock:
        bucket = _image_bytes_buffer.get(prefix)
        if bucket is None:
            # First entry for this prefix: create its list.
            bucket = []
            _image_bytes_buffer[prefix] = bucket
        bucket.append(entry)


def pop_image_bytes(prefix: str) -> list[tuple[str, str, bytes]]:
    """#### Remove and return every buffered image entry stored under a prefix.

    #### Args:
        - `prefix` (str): Key the entries were stored under.

    #### Returns:
        - `list`: (filename, subfolder, png_bytes) tuples; empty when nothing
          is buffered for `prefix`.
    """
    with _image_bytes_lock:
        entries = _image_bytes_buffer.pop(prefix, None)
    if entries is None:
        return []
    return entries


def get_output_directory() -> str:
    """#### Return the configured output directory.

    #### Returns:
        - `str`: Current value of the module-level `output_directory`.
    """
    # Reading a module-level name needs no `global` declaration.
    return output_directory


def get_save_image_path(
    filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0
) -> tuple:
    """#### Get the save image path.

    Expands the `%width%` / `%height%` placeholders in `filename_prefix`,
    creates the per-mode output subfolders, and picks the next free file
    counter by scanning those subfolders for `<filename>_<number>_...png` files.

    #### Args:
        - `filename_prefix` (str): The filename prefix (may contain a relative subfolder).
        - `output_dir` (str): The output directory.
        - `image_width` (int, optional): The image width. Defaults to 0.
        - `image_height` (int, optional): The image height. Defaults to 0.

    #### Returns:
        - `tuple`: The full output folder, filename, counter, subfolder, and filename prefix.
    """
    filename_prefix = filename_prefix.replace("%width%", str(image_width)).replace(
        "%height%", str(image_height)
    )

    normalized = os.path.normpath(filename_prefix)
    subfolder = os.path.dirname(normalized)
    filename = os.path.basename(normalized)
    full_output_folder = os.path.join(output_dir, subfolder)

    # Saved files are named "<filename>_<counter>_.png". Requiring the "_"
    # separator keeps unrelated files that merely share leading characters
    # (e.g. "LDX00050_.png" for prefix "LD") from bumping the counter.
    stem = filename + "_"

    def parse_counter(name: str) -> int:
        """Return the counter encoded in `name`, or 0 if it does not match."""
        if not (name.startswith(stem) and name.endswith(".png")):
            return 0
        try:
            return int(name[len(stem):].split("_")[0])
        except ValueError:
            return 0

    # "Flux" is included because the saver routes "LD-Flux" images there; it
    # must exist and take part in the counter scan to avoid overwriting files.
    modes = ("Classic", "HiresFix", "Img2Img", "Adetailer", "ControlNet", "Flux")

    highest = 0
    for mode in modes:
        path = os.path.join(full_output_folder, mode)
        os.makedirs(path, exist_ok=True)
        for entry in os.listdir(path):
            highest = max(highest, parse_counter(entry))

    return full_output_folder, filename, highest + 1, subfolder, filename_prefix


# NOTE(review): not referenced in the visible part of this file; presumably the
# largest supported image side length in pixels — confirm against its users.
MAX_RESOLUTION = 16384


class SaveImage:
    """#### Class for saving images."""

    def __init__(self):
        """#### Initialize the SaveImage class."""
        self.output_dir = get_output_directory()
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4

    def save_images(
        self,
        images: list,
        filename_prefix: str = "LD",
        prompt: str = None,
        extra_pnginfo: dict = None,
        store_bytes_prefix: str | None = None,
    ) -> dict:
        """#### Save images to the output directory.

        #### Args:
            - `images` (list): The list of images.
            - `filename_prefix` (str, optional): The filename prefix. Defaults to "LD".
            - `prompt` (str, optional): The prompt. Defaults to None.
            - `extra_pnginfo` (dict, optional): Additional PNG info. Defaults to None.
            - `store_bytes_prefix` (str, optional): If set, also buffer PNG bytes in memory
              under this key for zero-disk-IO API retrieval.

        #### Returns:
            - `dict`: The saved images information.
        """
        filename_prefix += self.prefix_append

        # Safety: compute total number of images to be saved in this call, counting
        # batched tensors as multiple images. Abort early if count exceeds threshold.
        total_images = 0
        for image in images:
            shape = getattr(image, 'shape', None)
            if shape is None:
                total_images += 1
                continue
            try:
                if len(shape) >= 4:
                    total_images += int(shape[0])
                else:
                    total_images += 1
            except Exception:
                total_images += 1

        if total_images > MAX_IMAGES_PER_SAVE:
            # Diagnostic: record basic info about incoming images to help trace
            # the source of excessive image counts (tiling issues, batched tensors)
            details = []
            try:
                for idx, image in enumerate(images[:10]):
                    try:
                        shape = getattr(image, 'shape', None)
                        dtype = getattr(image, 'dtype', None)
                        tname = type(image).__name__
                        details.append(f"idx={idx} type={tname} shape={shape} dtype={dtype}")
                    except Exception as e:
                        details.append(f"idx={idx} inspect_failed: {e}")
                more = f" (+{max(0, len(images)-10)} more)" if len(images) > 10 else ""
            except Exception:
                details = ["failed to enumerate images"]
                more = ""

            logger.warning(
                "Attempting to save %d images in a single call (exceeds MAX_IMAGES_PER_SAVE=%d). "
                "This may indicate tiled intermediate outputs; aborting save to avoid creating many tile files. "
                "filename_prefix=%s store_bytes_prefix=%s Details: %s%s",
                total_images,
                MAX_IMAGES_PER_SAVE,
                filename_prefix,
                store_bytes_prefix,
                "; ".join(details),
                more,
            )
            return {"ui": {"images": []}}

        full_output_folder, filename, counter, subfolder, filename_prefix = (
            get_save_image_path(
                filename_prefix,
                self.output_dir,
                images[0].shape[-2],
                images[0].shape[-1],
            )
        )
        results = list()
        for batch_number, image in enumerate(images):
            # Convert tensor to numpy and handle different dimensions
            i = image.cpu().numpy()

            # Handle batched tensors (4D: [batch, channels, height, width] or [batch, height, width, channels])
            if i.ndim == 4:
                # Process each image in the batch separately
                for sub_batch_idx in range(i.shape[0]):
                    sub_image = i[sub_batch_idx]  # Extract single image from batch

                    # Convert to HWC format if in CHW format
                    if sub_image.shape[0] in [1, 3, 4] and sub_image.shape[0] < min(
                        sub_image.shape[1], sub_image.shape[2]
                    ):
                        sub_image = np.transpose(sub_image, (1, 2, 0))  # CHW -> HWC

                    # Squeeze single channel dimension if present
                    if sub_image.shape[-1] == 1:
                        sub_image = sub_image.squeeze(-1)

                    # Scale to 0-255 range
                    sub_image_scaled = np.clip(sub_image * 255.0, 0, 255).astype(
                        np.uint8
                    )

                    img = Image.fromarray(sub_image_scaled)
                    # Attach PNG text metadata if provided
                    if extra_pnginfo:
                        metadata = PngInfo()
                        for k, v in extra_pnginfo.items():
                            try:
                                metadata.add_text(str(k), str(v))
                            except Exception:
                                # Ensure metadata writing never blocks saving
                                pass
                    else:
                        metadata = None

                    filename_with_batch_num = filename.replace(
                        "%batch_num%", str(batch_number)
                    )
                    file = f"{filename_with_batch_num}_{counter:05}_.png"

                    # Save the image to appropriate subfolder
                    save_path = full_output_folder
                    if filename_prefix == "LD-HF":
                        save_path = os.path.join(full_output_folder, "HiresFix")
                    elif filename_prefix == "LD-I2I":
                        save_path = os.path.join(full_output_folder, "Img2Img")
                    elif filename_prefix == "LD-CN":
                        save_path = os.path.join(full_output_folder, "ControlNet")
                    elif filename_prefix == "LD-head" or filename_prefix == "LD-body":
                        save_path = os.path.join(full_output_folder, "Adetailer")
                    else:
                        save_path = os.path.join(full_output_folder, "Classic")

                    img.save(
                        os.path.join(save_path, file),
                        pnginfo=metadata,
                        compress_level=self.compress_level,
                    )
                    # Buffer PNG bytes in memory for API responses (avoids re-read)
                    if store_bytes_prefix:
                        buf = io.BytesIO()
                        img.save(buf, format="PNG", pnginfo=metadata, compress_level=self.compress_level)
                        save_rel_bytes = os.path.relpath(save_path, "./output")
                        store_image_bytes(store_bytes_prefix, file, save_rel_bytes, buf.getvalue())
                    # Return the actual subfolder relative to ./output so callers can locate files
                    save_rel = os.path.relpath(save_path, "./output")
                    results.append(
                        {
                            "filename": file,
                            "subfolder": save_rel,
                            "requested_subfolder": subfolder,
                            "type": self.type,
                        }
                    )
                    counter += 1
                continue  # Skip the rest of the loop for this batch

            # Handle 3D tensors (single image: [channels, height, width] or [height, width, channels])
            elif i.ndim == 3:
                # Convert to HWC format if in CHW format
                if i.shape[0] in [1, 3, 4] and i.shape[0] < min(i.shape[1], i.shape[2]):
                    i = np.transpose(i, (1, 2, 0))  # CHW -> HWC

                # Squeeze single channel dimension if present
                if i.shape[-1] == 1:
                    i = i.squeeze(-1)

            # Handle 2D tensors (grayscale: [height, width])
            elif i.ndim == 2:
                pass  # Already in correct format
            else:
                raise ValueError(f"Unexpected tensor dimensions: {i.shape}")

            # Scale to 0-255 range and convert to PIL Image
            i_scaled = np.clip(i * 255.0, 0, 255).astype(np.uint8)
            img = Image.fromarray(i_scaled)
            # Attach PNG text metadata if provided
            if extra_pnginfo:
                metadata = PngInfo()
                for k, v in extra_pnginfo.items():
                    try:
                        metadata.add_text(str(k), str(v))
                    except Exception:
                        pass
            else:
                metadata = None

            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.png"
            # Save the image to appropriate subfolder
            save_path = full_output_folder
            if filename_prefix == "LD-HF":
                save_path = os.path.join(full_output_folder, "HiresFix")
            elif filename_prefix == "LD-I2I":
                save_path = os.path.join(full_output_folder, "Img2Img")
            elif filename_prefix == "LD-CN":
                save_path = os.path.join(full_output_folder, "ControlNet")
            elif filename_prefix == "LD-Flux":
                save_path = os.path.join(full_output_folder, "Flux")
            elif filename_prefix == "LD-head" or filename_prefix == "LD-body":
                save_path = os.path.join(full_output_folder, "Adetailer")
            else:
                save_path = os.path.join(full_output_folder, "Classic")

            img.save(
                os.path.join(save_path, file),
                pnginfo=metadata,
                compress_level=self.compress_level,
            )
            # Buffer PNG bytes in memory for API responses (avoids re-read)
            if store_bytes_prefix:
                buf = io.BytesIO()
                img.save(buf, format="PNG", pnginfo=metadata, compress_level=self.compress_level)
                save_rel_bytes = os.path.relpath(save_path, "./output")
                store_image_bytes(store_bytes_prefix, file, save_rel_bytes, buf.getvalue())
            # Return the actual subfolder relative to ./output so callers can locate files
            save_rel = os.path.relpath(save_path, "./output")
            results.append(
                {
                    "filename": file,
                    "subfolder": save_rel,
                    "requested_subfolder": subfolder,
                    "type": self.type,
                }
            )
            counter += 1

        return {"ui": {"images": results}}

    def save_images_async(
        self,
        images: list,
        filename_prefix: str = "LD",
        prompt: str = None,
        extra_pnginfo: dict = None,
    ) -> threading.Thread:
        """#### Save images asynchronously in a background thread.
        
        #### Returns:
            - `threading.Thread`: The background thread handling the save.
        """
        # Create copies of tensors on CPU to free GPU memory immediately
        cpu_images = [img.detach().cpu().clone() for img in images]
        
        thread = threading.Thread(
            target=self.save_images,
            args=(cpu_images, filename_prefix, prompt, extra_pnginfo)
        )
        thread.start()
        return thread