|
|
import os
|
|
|
import gc
|
|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import json
|
|
|
import spaces
|
|
|
import config
|
|
|
import utils
|
|
|
import logging
|
|
|
from PIL import Image, PngImagePlugin
|
|
|
from datetime import datetime
|
|
|
from diffusers.models import AutoencoderKL
|
|
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
|
|
from config import (
|
|
|
MODEL,
|
|
|
MIN_IMAGE_SIZE,
|
|
|
MAX_IMAGE_SIZE,
|
|
|
USE_TORCH_COMPILE,
|
|
|
ENABLE_CPU_OFFLOAD,
|
|
|
OUTPUT_DIR,
|
|
|
DEFAULT_NEGATIVE_PROMPT,
|
|
|
DEFAULT_ASPECT_RATIO,
|
|
|
examples,
|
|
|
sampler_list,
|
|
|
aspect_ratios,
|
|
|
style_list,
|
|
|
)
|
|
|
import time
|
|
|
from typing import List, Dict, Tuple, Optional
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Logging: configure the root logger once at import time. Every record shows
# a timestamp, the emitting logger's name, the level and the message.
# ---------------------------------------------------------------------------
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Environment-driven switches.
# ---------------------------------------------------------------------------
# True when utils detects Google Colab, or when IS_COLAB=1 is exported.
IS_COLAB = utils.is_google_colab() or os.environ.get("IS_COLAB") == "1"
# Hugging Face access token, or None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Example caching is opt-in (CACHE_EXAMPLES=1) and only when CUDA is present.
CACHE_EXAMPLES = torch.cuda.is_available() and os.environ.get("CACHE_EXAMPLES") == "1"

# ---------------------------------------------------------------------------
# PyTorch backend flags: deterministic cuDNN, no cuDNN autotuning, TF32
# allowed for CUDA matmuls.
# ---------------------------------------------------------------------------
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
|
|
|
|
|
|
class GenerationError(Exception):
    """Error type for problems with a generation request.

    The validation helpers in this module raise it with a message that is
    suitable for showing to the user; ``generate`` catches it and re-raises
    the message as a ``gr.Error``.
    """
|
|
|
|
|
|
def validate_prompt(prompt: str) -> str:
    """Validate the user prompt and return a cleaned-up copy.

    Args:
        prompt: Raw prompt text as entered by the user.

    Returns:
        The prompt with every ``"!,"`` rewritten as ``"! ,"`` and with
        leading/trailing whitespace removed.

    Raises:
        GenerationError: If ``prompt`` is not a ``str``, cannot be encoded
            as UTF-8 (for example because it contains a lone surrogate), or
            is empty / whitespace-only.
    """
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")

    # Only the UTF-8 round-trip can raise UnicodeError, so it is the only
    # statement kept inside the try block.
    try:
        prompt = prompt.encode('utf-8').decode('utf-8')
    except UnicodeError as err:
        # Chain explicitly so the original encoding failure stays visible
        # as the cause in tracebacks.
        raise GenerationError("Invalid characters in prompt") from err

    # Put a space between "!" and a directly following ",".
    prompt = prompt.replace("!,", "! ,")

    cleaned = prompt.strip()
    if not cleaned:
        raise GenerationError("Prompt cannot be empty")
    return cleaned
|
|
|
|
|
|
def validate_dimensions(width: int, height: int) -> None:
    """Check that both image dimensions lie within the configured limits.

    Args:
        width: Requested image width.
        height: Requested image height.

    Raises:
        GenerationError: If ``width`` or ``height`` is outside the inclusive
            range ``MIN_IMAGE_SIZE`` .. ``MAX_IMAGE_SIZE``. Width is checked
            first.
    """
    width_in_range = MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE
    if not width_in_range:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")

    height_in_range = MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE
    if not height_in_range:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
|
|
|
|
|
|
@spaces.GPU(duration=25)
def generate(
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 6.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
    style_selector: str = "(None)",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    add_quality_tags: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[List[str], Dict]:
    """Generate images for ``prompt`` and save them under ``OUTPUT_DIR``.

    The request is validated, the prompt is optionally prefixed with quality
    tags and passed through the selected style preset, the requested sampler
    is swapped into the shared module-level ``pipe``, and every produced
    image is written to disk with ``utils.save_image``. With
    ``use_upscaler`` the base pass returns latents, which are enlarged by
    ``utils.upscale`` and refined by an img2img pipeline built from the same
    components.

    Args:
        prompt: Positive prompt; validated with ``validate_prompt``.
        negative_prompt: Negative prompt; may be empty.
        seed: Seed handed to ``utils.seed_everything``.
        custom_width: Width passed to ``utils.aspect_ratio_handler``
            together with ``aspect_ratio_selector``.
        custom_height: Height passed to ``utils.aspect_ratio_handler``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.
        sampler: Sampler name resolved by ``utils.get_scheduler``.
        aspect_ratio_selector: Aspect-ratio choice from the UI.
        style_selector: Name of the style preset looked up in ``styles``.
        use_upscaler: Run the latent-upscale + img2img second pass.
        upscaler_strength: ``strength`` for the img2img pass.
        upscale_by: Scale factor applied to the latents.
        add_quality_tags: Prefix the prompt with fixed quality tags.
        progress: Gradio progress tracker, used while saving images.

    Returns:
        ``(image_paths, metadata)``: the paths returned by
        ``utils.save_image`` (an empty list if the pipeline produced no
        images) and the dictionary describing the generation parameters,
        including ``"generation_time"``.

    Raises:
        gr.Error: For validation problems (``GenerationError``) and for any
            other failure during generation; the original exception is
            attached as the cause.
    """
    start_time = time.time()
    upscaler_pipe = None
    backup_scheduler = None

    try:
        # Start from as clean a memory state as possible.
        torch.cuda.empty_cache()
        gc.collect()

        # ---- Input validation -------------------------------------------
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
        validate_dimensions(custom_width, custom_height)

        # ``pipe`` is set to None at import time when CUDA is unavailable.
        # Fail with a readable message instead of an AttributeError on
        # ``pipe.scheduler`` further down.
        if pipe is None:
            raise GenerationError(
                "Model pipeline is not loaded (CUDA is not available)"
            )

        # ---- Prompt and size preparation --------------------------------
        generator = utils.seed_everything(seed)
        width, height = utils.aspect_ratio_handler(
            aspect_ratio_selector,
            custom_width,
            custom_height,
        )

        if add_quality_tags:
            prompt = f"masterpiece, high score, great score, absurdres, {prompt}"

        prompt, negative_prompt = utils.preprocess_prompt(
            styles, style_selector, prompt, negative_prompt
        )
        width, height = utils.preprocess_image_dimensions(width, height)

        # ---- Pipeline setup ---------------------------------------------
        # Remember the current scheduler so ``finally`` can restore it; the
        # pipeline object is shared between requests.
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)

        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "resolution": f"{width} x {height}",
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "style_preset": style_selector,
            "seed": seed,
            "sampler": sampler,
            "Model": "Animagine XL 4.0",
            "Model hash": "e3c47aedb0",
        }

        upscale_method = "nearest-exact"
        if use_upscaler:
            # The img2img pipeline reuses the components of ``pipe``.
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
            new_width = int(width * upscale_by)
            new_height = int(height * upscale_by)
            metadata["use_upscaler"] = {
                "upscale_method": upscale_method,
                "upscaler_strength": upscaler_strength,
                "upscale_by": upscale_by,
                "new_resolution": f"{new_width} x {new_height}",
            }
        else:
            metadata["use_upscaler"] = None

        logger.info(
            "Starting generation with parameters: %s",
            json.dumps(metadata, indent=4),
        )

        # ---- Generation -------------------------------------------------
        # Keyword arguments common to every pipeline call below.
        shared_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }

        if use_upscaler:
            latents = pipe(
                width=width,
                height=height,
                output_type="latent",
                **shared_kwargs,
            ).images
            upscaled_latents = utils.upscale(latents, upscale_method, upscale_by)
            images = upscaler_pipe(
                image=upscaled_latents,
                strength=upscaler_strength,
                output_type="pil",
                **shared_kwargs,
            ).images
        else:
            images = pipe(
                width=width,
                height=height,
                output_type="pil",
                **shared_kwargs,
            ).images

        # ---- Saving -----------------------------------------------------
        # Bound up front so the return below is valid even when the pipeline
        # produced no images.
        image_paths: List[str] = []
        if images:
            total = len(images)
            for idx, image in enumerate(images, 1):
                progress(idx / total, desc="Saving images...")
                path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
                image_paths.append(path)
                logger.info("Image %s/%s saved as %s", idx, total, path)

        generation_time = time.time() - start_time
        logger.info(
            "Generation completed successfully in %.2f seconds", generation_time
        )
        metadata["generation_time"] = f"{generation_time:.2f}s"

        return image_paths, metadata

    except GenerationError as e:
        logger.warning("Generation validation error: %s", e)
        raise gr.Error(str(e)) from e
    except Exception as e:
        # Boundary handler: log the traceback and surface a UI error.
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        # Put the original scheduler back so the next request starts from
        # the pipeline's default sampler.
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler

        # Drop the local reference first, then collect and empty the CUDA
        # cache, so memory that became unreachable can actually be released.
        upscaler_pipe = None
        gc.collect()
        torch.cuda.empty_cache()
        utils.free_memory()
|
|
|
|
|
|
|
|
|
# Load the model once at import time. Without CUDA no pipeline is created
# and ``pipe`` stays None.
if not torch.cuda.is_available():
    logger.warning("CUDA not available, running on CPU")
    pipe = None
else:
    try:
        logger.info("Loading VAE and pipeline...")
        # fp16 VAE, loaded separately and handed to the pipeline loader.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        # Any failure above (VAE download or pipeline construction) falls
        # back to loading the pipeline without the custom VAE.
        logger.error(f"Error loading VAE, falling back to default: {e}")
        pipe = utils.load_pipeline(MODEL, device)

# Style presets keyed by name; each value is (prompt, negative_prompt).
styles = {
    preset["name"]: (preset["prompt"], preset["negative_prompt"])
    for preset in style_list
}
|
|
|
|
|
|
# Gradio UI definition and event wiring for the demo.
# NOTE(review): the container nesting below was reconstructed from the order
# of the widgets; confirm it against the rendered layout.
with gr.Blocks(css="style.css", theme="Nymbo/Nymbo_Theme_5") as demo:
    # Page header with a link to the model card.
    gr.HTML(
        """
        <div class="header">
            <div class="title">ANIM4GINE</div>
            <div class="subtitle">Gradio demo for <a href="https://huggingface.co/CagliostroLab/Animagine-XL-4.0" target="_blank">Animagine XL 4.0</a></div>
        </div>
        """,
    )

    with gr.Row():
        # Left column: all generation inputs.
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to generate",
                    info="Enter your image generation prompt here. Be specific and descriptive for better results.",
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    max_lines=5,
                    placeholder="Describe what you want to avoid",
                    value=DEFAULT_NEGATIVE_PROMPT,
                    info="Specify elements you don't want in the image.",
                )
                add_quality_tags = gr.Checkbox(
                    label="Quality Tags",
                    value=True,
                    info="Add quality-enhancing tags to your prompt automatically.",
                )
            with gr.Accordion(label="More Settings", open=False):
                with gr.Group():
                    aspect_ratio_selector = gr.Radio(
                        label="Aspect Ratio",
                        choices=aspect_ratios,
                        value=DEFAULT_ASPECT_RATIO,
                        container=True,
                        info="Choose the dimensions of your image.",
                    )
                # Hidden by default; shown by the aspect_ratio_selector
                # change handler below when "Custom" is selected.
                with gr.Group(visible=False) as custom_resolution:
                    with gr.Row():
                        custom_width = gr.Slider(
                            label="Width",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image width (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                        custom_height = gr.Slider(
                            label="Height",
                            minimum=MIN_IMAGE_SIZE,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                            info=f"Image height (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                        )
                with gr.Group():
                    use_upscaler = gr.Checkbox(
                        label="Use Upscaler",
                        value=False,
                        info="Enable high-resolution upscaling.",
                    )
                    # Both sliders start hidden; the use_upscaler change
                    # handler below toggles their visibility.
                    with gr.Row() as upscaler_row:
                        upscaler_strength = gr.Slider(
                            label="Strength",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.55,
                            visible=False,
                            info="Control how much the upscaler affects the final image.",
                        )
                        upscale_by = gr.Slider(
                            label="Upscale by",
                            minimum=1,
                            maximum=1.5,
                            step=0.1,
                            value=1.5,
                            visible=False,
                            info="Multiplier for the final image resolution.",
                        )
            with gr.Accordion(label="Advanced Parameters", open=False):
                with gr.Group():
                    style_selector = gr.Dropdown(
                        label="Style Preset",
                        interactive=True,
                        choices=list(styles.keys()),
                        value="(None)",
                        info="Apply a predefined style to your generation.",
                    )
                with gr.Group():
                    sampler = gr.Dropdown(
                        label="Sampler",
                        choices=sampler_list,
                        interactive=True,
                        value="Euler a",
                        info="Different samplers can produce varying results.",
                    )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=utils.MAX_SEED,
                        step=1,
                        value=0,
                        info="Set a specific seed for reproducible results.",
                    )
                    # Read by utils.randomize_seed_fn in the event chain
                    # at the bottom of this block.
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                        info="Generate a new random seed for each image.",
                    )
                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=6.0,
                            info="Higher values make the image more closely match your prompt.",
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                            info="More steps generally mean higher quality but slower generation.",
                        )

        # Right column: trigger button, results and examples.
        with gr.Column(scale=3):
            with gr.Blocks():
                run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")
            result = gr.Gallery(
                label="Generated Images",
                columns=1,
                height='768px',
                preview=True,
                show_label=True,
            )
            with gr.Accordion(label="Generation Parameters", open=False):
                gr_metadata = gr.JSON(
                    label="Image Metadata",
                    show_label=True,
                )
            # Examples supply only the prompt; the wrapper forces
            # use_upscaler=True and leaves every other argument of
            # generate() at its default.
            gr.Examples(
                examples=examples,
                inputs=prompt,
                outputs=[result, gr_metadata],
                fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
                cache_examples=CACHE_EXAMPLES,
            )

    # Extra row holding an HTML component whose content is currently blank.
    with gr.Row():
        gr.HTML(
            """
            """
        )

    # Show or hide both upscaler sliders together with the checkbox state.
    use_upscaler.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=use_upscaler,
        outputs=[upscaler_strength, upscale_by],
        queue=False,
        api_name=False,
    )
    # Reveal the custom width/height group only for the "Custom" ratio.
    aspect_ratio_selector.change(
        fn=lambda x: gr.update(visible=x == "Custom"),
        inputs=aspect_ratio_selector,
        outputs=custom_resolution,
        queue=False,
        api_name=False,
    )

    # Main generation chain, started by submitting either prompt box or by
    # clicking the button:
    #   1. utils.randomize_seed_fn updates the seed slider,
    #   2. the button is disabled and relabelled "Generating...",
    #   3. generate() runs and fills the gallery and metadata panel,
    #   4. the button is re-enabled and relabelled "Generate".
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=False, value="Generating..."),
        outputs=run_button,
    ).then(
        fn=generate,
        # Order must match the positional parameters of generate().
        inputs=[
            prompt,
            negative_prompt,
            seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            style_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
            add_quality_tags,
        ],
        outputs=[result, gr_metadata],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Generate"),
        outputs=run_button,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: turn on the request queue with the API left open,
    # then launch the app with the API page and error display enabled.
    queued_demo = demo.queue(api_open=True)
    queued_demo.launch(show_api=True, show_error=True)
|
|
|
|