Spaces:
Running on Zero
Running on Zero
| import io | |
| import os | |
| import threading | |
| import queue | |
| import numpy as np | |
| import logging | |
| from PIL import Image | |
| from PIL.PngImagePlugin import PngInfo | |
logger = logging.getLogger(__name__)

# Root directory that saved images are written under.
output_directory = "./output"

# Maximum number of images that will be saved in a single `save_images` call.
# Higher counts are likely to indicate tiled intermediate outputs, which should
# not be saved as individual image files to avoid filling the disk with tiles.
# Can be configured at runtime via the `LD_MAX_IMAGES_PER_SAVE` environment
# variable (default: 16).
try:
    MAX_IMAGES_PER_SAVE = int(os.getenv("LD_MAX_IMAGES_PER_SAVE", "16"))
except ValueError:
    # A malformed override must not prevent the module from importing; fall
    # back to the documented default and make the problem visible in the logs.
    logger.warning(
        "Invalid LD_MAX_IMAGES_PER_SAVE=%r; falling back to 16",
        os.getenv("LD_MAX_IMAGES_PER_SAVE"),
    )
    MAX_IMAGES_PER_SAVE = 16
# In-memory staging area for encoded images so API responses can be served
# without a disk round-trip.
# Keyed by request filename prefix; each value is a list of
# (filename, subfolder, png_bytes) entries in insertion order.
_image_bytes_buffer: dict[str, list[tuple[str, str, bytes]]] = {}
# Guards every read and write of `_image_bytes_buffer`.
_image_bytes_lock = threading.Lock()


def store_image_bytes(prefix: str, filename: str, subfolder: str, data: bytes) -> None:
    """Append one image's bytes to the in-memory buffer for later retrieval by the API server.

    The entry is stored as a `(filename, subfolder, data)` tuple under `prefix`.
    """
    entry = (filename, subfolder, data)
    with _image_bytes_lock:
        if prefix not in _image_bytes_buffer:
            _image_bytes_buffer[prefix] = []
        _image_bytes_buffer[prefix].append(entry)


def pop_image_bytes(prefix: str) -> list[tuple[str, str, bytes]]:
    """Remove and return every buffered entry stored under `prefix`.

    Returns a list of (filename, subfolder, png_bytes) tuples, or an empty
    list when nothing is buffered for that prefix.
    """
    with _image_bytes_lock:
        entries = _image_bytes_buffer.pop(prefix, None)
    if entries is None:
        return []
    return entries
def get_output_directory() -> str:
    """#### Return the configured output directory.

    #### Returns:
        - `str`: The current value of the module-level `output_directory`.
    """
    # Reading a module global needs no `global` declaration.
    return output_directory
def get_save_image_path(
    filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0
) -> tuple:
    """#### Get the save image path.

    Expands the `%width%` / `%height%` placeholders in the prefix, makes sure
    the per-mode subfolders exist, and picks the next free file counter by
    scanning those subfolders for existing `<filename>_<number>_*.png` files.

    #### Args:
        - `filename_prefix` (str): The filename prefix, optionally containing a
          relative subfolder and `%width%` / `%height%` placeholders.
        - `output_dir` (str): The output directory.
        - `image_width` (int, optional): Value substituted for `%width%`. Defaults to 0.
        - `image_height` (int, optional): Value substituted for `%height%`. Defaults to 0.

    #### Returns:
        - `tuple`: The full output folder, filename, counter, subfolder, and
          the placeholder-expanded filename prefix.
    """
    filename_prefix = filename_prefix.replace("%width%", str(image_width))
    filename_prefix = filename_prefix.replace("%height%", str(image_height))

    normalized_prefix = os.path.normpath(filename_prefix)
    subfolder = os.path.dirname(normalized_prefix)
    filename = os.path.basename(normalized_prefix)
    full_output_folder = os.path.join(output_dir, subfolder)

    # "Flux" is included because `SaveImage.save_images` writes LD-Flux images
    # there; without it the folder would be missing and its files would be
    # ignored when choosing the counter.
    subfolder_paths = [
        os.path.join(full_output_folder, name)
        for name in ("Classic", "HiresFix", "Img2Img", "Adetailer", "ControlNet", "Flux")
    ]

    # Saved files are named "<filename>_<counter>_.png"; requiring the
    # separator keeps a longer, unrelated prefix (e.g. "LDX…") from being
    # mistaken for this one.
    stem = filename + "_"

    # Find the highest counter across all subfolders.
    counter = 1
    for path in subfolder_paths:
        os.makedirs(path, exist_ok=True)
        for entry in os.listdir(path):
            if not (entry.startswith(stem) and entry.endswith(".png")):
                continue
            try:
                number = int(entry[len(stem):].split("_")[0])
            except ValueError:
                # Not a counter-style name; treat as "no number".
                number = 0
            counter = max(counter, number + 1)

    return full_output_folder, filename, counter, subfolder, filename_prefix
# Resolution limit constant (16384). Not referenced by the code visible in
# this part of the file — presumably consumed elsewhere; verify before changing.
MAX_RESOLUTION = 2**14
| class SaveImage: | |
| """#### Class for saving images.""" | |
| def __init__(self): | |
| """#### Initialize the SaveImage class.""" | |
| self.output_dir = get_output_directory() | |
| self.type = "output" | |
| self.prefix_append = "" | |
| self.compress_level = 4 | |
| def save_images( | |
| self, | |
| images: list, | |
| filename_prefix: str = "LD", | |
| prompt: str = None, | |
| extra_pnginfo: dict = None, | |
| store_bytes_prefix: str | None = None, | |
| ) -> dict: | |
| """#### Save images to the output directory. | |
| #### Args: | |
| - `images` (list): The list of images. | |
| - `filename_prefix` (str, optional): The filename prefix. Defaults to "LD". | |
| - `prompt` (str, optional): The prompt. Defaults to None. | |
| - `extra_pnginfo` (dict, optional): Additional PNG info. Defaults to None. | |
| - `store_bytes_prefix` (str, optional): If set, also buffer PNG bytes in memory | |
| under this key for zero-disk-IO API retrieval. | |
| #### Returns: | |
| - `dict`: The saved images information. | |
| """ | |
| filename_prefix += self.prefix_append | |
| # Safety: compute total number of images to be saved in this call, counting | |
| # batched tensors as multiple images. Abort early if count exceeds threshold. | |
| total_images = 0 | |
| for image in images: | |
| shape = getattr(image, 'shape', None) | |
| if shape is None: | |
| total_images += 1 | |
| continue | |
| try: | |
| if len(shape) >= 4: | |
| total_images += int(shape[0]) | |
| else: | |
| total_images += 1 | |
| except Exception: | |
| total_images += 1 | |
| if total_images > MAX_IMAGES_PER_SAVE: | |
| # Diagnostic: record basic info about incoming images to help trace | |
| # the source of excessive image counts (tiling issues, batched tensors) | |
| details = [] | |
| try: | |
| for idx, image in enumerate(images[:10]): | |
| try: | |
| shape = getattr(image, 'shape', None) | |
| dtype = getattr(image, 'dtype', None) | |
| tname = type(image).__name__ | |
| details.append(f"idx={idx} type={tname} shape={shape} dtype={dtype}") | |
| except Exception as e: | |
| details.append(f"idx={idx} inspect_failed: {e}") | |
| more = f" (+{max(0, len(images)-10)} more)" if len(images) > 10 else "" | |
| except Exception: | |
| details = ["failed to enumerate images"] | |
| more = "" | |
| logger.warning( | |
| "Attempting to save %d images in a single call (exceeds MAX_IMAGES_PER_SAVE=%d). " | |
| "This may indicate tiled intermediate outputs; aborting save to avoid creating many tile files. " | |
| "filename_prefix=%s store_bytes_prefix=%s Details: %s%s", | |
| total_images, | |
| MAX_IMAGES_PER_SAVE, | |
| filename_prefix, | |
| store_bytes_prefix, | |
| "; ".join(details), | |
| more, | |
| ) | |
| return {"ui": {"images": []}} | |
| full_output_folder, filename, counter, subfolder, filename_prefix = ( | |
| get_save_image_path( | |
| filename_prefix, | |
| self.output_dir, | |
| images[0].shape[-2], | |
| images[0].shape[-1], | |
| ) | |
| ) | |
| results = list() | |
| for batch_number, image in enumerate(images): | |
| # Convert tensor to numpy and handle different dimensions | |
| i = image.cpu().numpy() | |
| # Handle batched tensors (4D: [batch, channels, height, width] or [batch, height, width, channels]) | |
| if i.ndim == 4: | |
| # Process each image in the batch separately | |
| for sub_batch_idx in range(i.shape[0]): | |
| sub_image = i[sub_batch_idx] # Extract single image from batch | |
| # Convert to HWC format if in CHW format | |
| if sub_image.shape[0] in [1, 3, 4] and sub_image.shape[0] < min( | |
| sub_image.shape[1], sub_image.shape[2] | |
| ): | |
| sub_image = np.transpose(sub_image, (1, 2, 0)) # CHW -> HWC | |
| # Squeeze single channel dimension if present | |
| if sub_image.shape[-1] == 1: | |
| sub_image = sub_image.squeeze(-1) | |
| # Scale to 0-255 range | |
| sub_image_scaled = np.clip(sub_image * 255.0, 0, 255).astype( | |
| np.uint8 | |
| ) | |
| img = Image.fromarray(sub_image_scaled) | |
| # Attach PNG text metadata if provided | |
| if extra_pnginfo: | |
| metadata = PngInfo() | |
| for k, v in extra_pnginfo.items(): | |
| try: | |
| metadata.add_text(str(k), str(v)) | |
| except Exception: | |
| # Ensure metadata writing never blocks saving | |
| pass | |
| else: | |
| metadata = None | |
| filename_with_batch_num = filename.replace( | |
| "%batch_num%", str(batch_number) | |
| ) | |
| file = f"{filename_with_batch_num}_{counter:05}_.png" | |
| # Save the image to appropriate subfolder | |
| save_path = full_output_folder | |
| if filename_prefix == "LD-HF": | |
| save_path = os.path.join(full_output_folder, "HiresFix") | |
| elif filename_prefix == "LD-I2I": | |
| save_path = os.path.join(full_output_folder, "Img2Img") | |
| elif filename_prefix == "LD-CN": | |
| save_path = os.path.join(full_output_folder, "ControlNet") | |
| elif filename_prefix == "LD-head" or filename_prefix == "LD-body": | |
| save_path = os.path.join(full_output_folder, "Adetailer") | |
| else: | |
| save_path = os.path.join(full_output_folder, "Classic") | |
| img.save( | |
| os.path.join(save_path, file), | |
| pnginfo=metadata, | |
| compress_level=self.compress_level, | |
| ) | |
| # Buffer PNG bytes in memory for API responses (avoids re-read) | |
| if store_bytes_prefix: | |
| buf = io.BytesIO() | |
| img.save(buf, format="PNG", pnginfo=metadata, compress_level=self.compress_level) | |
| save_rel_bytes = os.path.relpath(save_path, "./output") | |
| store_image_bytes(store_bytes_prefix, file, save_rel_bytes, buf.getvalue()) | |
| # Return the actual subfolder relative to ./output so callers can locate files | |
| save_rel = os.path.relpath(save_path, "./output") | |
| results.append( | |
| { | |
| "filename": file, | |
| "subfolder": save_rel, | |
| "requested_subfolder": subfolder, | |
| "type": self.type, | |
| } | |
| ) | |
| counter += 1 | |
| continue # Skip the rest of the loop for this batch | |
| # Handle 3D tensors (single image: [channels, height, width] or [height, width, channels]) | |
| elif i.ndim == 3: | |
| # Convert to HWC format if in CHW format | |
| if i.shape[0] in [1, 3, 4] and i.shape[0] < min(i.shape[1], i.shape[2]): | |
| i = np.transpose(i, (1, 2, 0)) # CHW -> HWC | |
| # Squeeze single channel dimension if present | |
| if i.shape[-1] == 1: | |
| i = i.squeeze(-1) | |
| # Handle 2D tensors (grayscale: [height, width]) | |
| elif i.ndim == 2: | |
| pass # Already in correct format | |
| else: | |
| raise ValueError(f"Unexpected tensor dimensions: {i.shape}") | |
| # Scale to 0-255 range and convert to PIL Image | |
| i_scaled = np.clip(i * 255.0, 0, 255).astype(np.uint8) | |
| img = Image.fromarray(i_scaled) | |
| # Attach PNG text metadata if provided | |
| if extra_pnginfo: | |
| metadata = PngInfo() | |
| for k, v in extra_pnginfo.items(): | |
| try: | |
| metadata.add_text(str(k), str(v)) | |
| except Exception: | |
| pass | |
| else: | |
| metadata = None | |
| filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) | |
| file = f"{filename_with_batch_num}_{counter:05}_.png" | |
| # Save the image to appropriate subfolder | |
| save_path = full_output_folder | |
| if filename_prefix == "LD-HF": | |
| save_path = os.path.join(full_output_folder, "HiresFix") | |
| elif filename_prefix == "LD-I2I": | |
| save_path = os.path.join(full_output_folder, "Img2Img") | |
| elif filename_prefix == "LD-CN": | |
| save_path = os.path.join(full_output_folder, "ControlNet") | |
| elif filename_prefix == "LD-Flux": | |
| save_path = os.path.join(full_output_folder, "Flux") | |
| elif filename_prefix == "LD-head" or filename_prefix == "LD-body": | |
| save_path = os.path.join(full_output_folder, "Adetailer") | |
| else: | |
| save_path = os.path.join(full_output_folder, "Classic") | |
| img.save( | |
| os.path.join(save_path, file), | |
| pnginfo=metadata, | |
| compress_level=self.compress_level, | |
| ) | |
| # Buffer PNG bytes in memory for API responses (avoids re-read) | |
| if store_bytes_prefix: | |
| buf = io.BytesIO() | |
| img.save(buf, format="PNG", pnginfo=metadata, compress_level=self.compress_level) | |
| save_rel_bytes = os.path.relpath(save_path, "./output") | |
| store_image_bytes(store_bytes_prefix, file, save_rel_bytes, buf.getvalue()) | |
| # Return the actual subfolder relative to ./output so callers can locate files | |
| save_rel = os.path.relpath(save_path, "./output") | |
| results.append( | |
| { | |
| "filename": file, | |
| "subfolder": save_rel, | |
| "requested_subfolder": subfolder, | |
| "type": self.type, | |
| } | |
| ) | |
| counter += 1 | |
| return {"ui": {"images": results}} | |
| def save_images_async( | |
| self, | |
| images: list, | |
| filename_prefix: str = "LD", | |
| prompt: str = None, | |
| extra_pnginfo: dict = None, | |
| ) -> threading.Thread: | |
| """#### Save images asynchronously in a background thread. | |
| #### Returns: | |
| - `threading.Thread`: The background thread handling the save. | |
| """ | |
| # Create copies of tensors on CPU to free GPU memory immediately | |
| cpu_images = [img.detach().cpu().clone() for img in images] | |
| thread = threading.Thread( | |
| target=self.save_images, | |
| args=(cpu_images, filename_prefix, prompt, extra_pnginfo) | |
| ) | |
| thread.start() | |
| return thread | |