|
|
import gc
|
|
|
import os
|
|
|
import random
|
|
|
import numpy as np
|
|
|
import json
|
|
|
import torch
|
|
|
import uuid
|
|
|
from PIL import Image, PngImagePlugin
|
|
|
from datetime import datetime
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Callable, Dict, Optional, Tuple, Any, List
|
|
|
from diffusers import (
|
|
|
DDIMScheduler,
|
|
|
DPMSolverMultistepScheduler,
|
|
|
DPMSolverSinglestepScheduler,
|
|
|
EulerAncestralDiscreteScheduler,
|
|
|
EulerDiscreteScheduler,
|
|
|
AutoencoderKL,
|
|
|
StableDiffusionXLPipeline,
|
|
|
)
|
|
|
import logging
|
|
|
|
|
|
# Largest value accepted as a seed: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
|
|
@dataclass
class StyleConfig:
    """Prompt templates for one named generation style.

    NOTE(review): presumably `prompt` contains a ``{prompt}`` placeholder
    that is filled with the user's text (matching how ``preprocess_prompt``
    formats style tuples) — confirm against the callers that build these.
    """

    # Positive-prompt template for the style.
    prompt: str
    # Negative-prompt text associated with the style.
    negative_prompt: str
|
|
|
|
|
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in [0, MAX_SEED] when *randomize_seed*
    is set; otherwise pass *seed* through unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
|
|
|
|
|
|
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices) and NumPy with *seed*.

    Returns a CPU ``torch.Generator`` initialised with the same seed,
    suitable for passing to sampling calls.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Generator.manual_seed returns the generator itself, so seed fluently.
    return torch.Generator().manual_seed(seed)
|
|
|
|
|
|
|
|
|
def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    """Parse a ``"W x H"`` label into ``(width, height)`` integers.

    The special label ``"Custom"`` yields ``None`` so callers can fall
    back to user-supplied dimensions.
    """
    if aspect_ratio == "Custom":
        return None
    w_text, h_text = aspect_ratio.split(" x ")
    return int(w_text), int(h_text)
|
|
|
|
|
|
|
|
|
def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    """Resolve the final (width, height) for a generation request.

    ``"Custom"`` selects the user-supplied dimensions; any other label is
    parsed as a ``"W x H"`` preset.
    """
    if aspect_ratio != "Custom":
        return parse_aspect_ratio(aspect_ratio)
    return custom_width, custom_height
|
|
|
|
|
|
|
|
|
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Instantiate the diffusers scheduler matching *name*.

    Each entry is a zero-argument factory so only the selected scheduler
    is constructed. Unknown names yield ``None``.
    """
    karras = {"use_karras_sigmas": True}
    factories = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, **karras
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, **karras
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, algorithm_type="sde-dpmsolver++", **karras
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
            scheduler_config
        ),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    factory = factories.get(name)
    return factory() if factory is not None else None
|
|
|
|
|
|
|
|
|
def free_memory() -> None:
    """Free up GPU and system memory.

    Empties the CUDA caching allocator and IPC handles when a GPU is
    present, then runs the Python garbage collector.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
|
|
|
|
|
|
|
|
|
def preprocess_prompt(
    style_dict,
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    """Apply a named style template to a positive/negative prompt pair.

    ``style_dict`` maps style names to ``(positive_template,
    negative_text)`` tuples; unknown names fall back to the ``"(None)"``
    entry. The positive template is formatted with ``{prompt}`` only when
    *add_style* is set and *positive* is non-blank. The user's *negative*
    text is appended (comma-separated) after the style's negative text.
    """
    style_pos, style_neg = style_dict.get(style_name, style_dict["(None)"])

    final_positive = positive
    if add_style and positive.strip():
        final_positive = style_pos.format(prompt=positive)

    pieces = [style_neg] if style_neg else []
    if negative.strip():
        pieces.append(negative)
    final_negative = ", ".join(pieces)

    return final_positive, final_negative
|
|
|
|
|
|
|
|
|
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize an NCHW batch to exactly (height, width).

    *upscale_method* is passed straight through as the ``mode`` argument
    of ``torch.nn.functional.interpolate`` (e.g. "nearest", "bilinear").
    """
    target_size = (height, width)
    return torch.nn.functional.interpolate(
        samples, size=target_size, mode=upscale_method
    )
|
|
|
|
|
|
|
|
|
def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Scale an NCHW batch by *scale_by*, rounding each target side to the
    nearest integer pixel count."""
    _, _, cur_h, cur_w = samples.shape
    new_w = round(cur_w * scale_by)
    new_h = round(cur_h * scale_by)
    return common_upscale(samples, new_w, new_h, upscale_method)
|
|
|
|
|
|
|
|
|
def preprocess_image_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Round *width* and *height* down to the nearest multiple of 8.

    NOTE(review): the multiple-of-8 constraint is assumed to come from the
    SDXL VAE's 8x latent downsampling — confirm against the pipeline.

    Args:
        width: Requested image width in pixels.
        height: Requested image height in pixels.

    Returns:
        ``(width, height)`` each floored to a multiple of 8.
    """
    # `x - x % 8` is already a no-op when x is a multiple of 8, so the
    # original `if x % 8 != 0` guards were redundant.
    return width - width % 8, height - height % 8
|
|
|
|
|
|
|
|
|
def save_image(image, metadata, output_dir, is_colab):
    """Save *image* as a PNG under *output_dir* and return its path.

    *metadata* is JSON-encoded and embedded in the PNG's "parameters"
    text chunk. Colab runs get a timestamped filename for easy sorting;
    otherwise a UUID filename avoids collisions.
    """
    if is_colab:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"image_{stamp}.png"
    else:
        filename = f"{uuid.uuid4()}.png"

    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)

    png_info = PngImagePlugin.PngInfo()
    png_info.add_text("parameters", json.dumps(metadata))
    image.save(filepath, "PNG", pnginfo=png_info)
    return filepath
|
|
|
|
|
|
|
|
|
def is_google_colab() -> bool:
    """Return True when running inside a Google Colab runtime.

    Detection works by importing the Colab-only ``google.colab`` module.

    Returns:
        True if the import succeeds, False otherwise.
    """
    try:
        import google.colab  # noqa: F401 -- presence implies a Colab runtime
    except ImportError:
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; only a failed import means "not Colab".
        return False
    return True
|
|
|
|
|
|
|
|
|
def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str] = None, vae: Optional[AutoencoderKL] = None) -> Any:
    """Load the Stable Diffusion pipeline.

    A ``.safetensors`` *model_name* is treated as a local single-file
    checkpoint; anything else is loaded as a Hugging Face repo id. The
    pipeline is built in fp16 with the long-prompt-weighting custom
    pipeline and moved to *device* before being returned.

    NOTE(review): *hf_token* is accepted but not used in this body —
    kept for interface compatibility; confirm whether callers rely on it.

    Raises:
        Exception: re-raises any loader failure after logging it.
    """
    try:
        if model_name.endswith(".safetensors"):
            loader = StableDiffusionXLPipeline.from_single_file
        else:
            loader = StableDiffusionXLPipeline.from_pretrained

        pipe = loader(
            model_name,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            use_safetensors=True,
            add_watermarker=False,
        )
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
        raise
|
|
|
|