|
|
|
|
|
""" |
|
|
UltraPixel Multi-Stage High-Resolution Generator |
|
|
Fixed parameter control with independent GPU allocation per stage |
|
|
""" |
|
|
|
|
|
import spaces |
|
|
import os |
|
|
import torch |
|
|
import yaml |
|
|
import sys |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from typing import Tuple |
|
|
import datetime |
|
|
import random |
|
|
|
|
|
# Put the current working directory on the import path; presumably this is
# what lets the `inference`, `train` and `gdf` imports below resolve — confirm.
sys.path.append(os.path.abspath('./'))

# Runtime switches for PyTorch, safetensors and the Hugging Face hub client.
# NOTE(review): these are set after `torch`, `gradio` and `spaces` have been
# imported; libraries that read their flags at import time may not see them.
os.environ.update({
    'PYTORCH_NVML_BASED_CUDA_CHECK': '1',
    'PYTORCH_ALLOC_CONF': 'expandable_segments:True',
    'SAFETENSORS_FAST_GPU': '1',
    'HF_HUB_ENABLE_HF_TRANSFER': '1',
})

# Allow TF32 for cuDNN and CUDA matmul, and request "high" float32 matmul
# precision.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
|
|
|
|
|
from inference.utils import * |
|
|
from train import WurstCoreB, WurstCore_t2i as WurstCoreC |
|
|
from gdf import DDPMSampler |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# dtype handed to autocast by both generation stages.
dtype = torch.bfloat16

# Directory where Stage C writes latents for Stage B to read back; created at
# import time if it does not exist yet.
LATENT_DIR = "/tmp/ultrapixel_latents"
os.makedirs(LATENT_DIR, exist_ok=True)
|
|
|
|
|
# Markdown rendered at the top of the Gradio page.
# NOTE(review): the emoji, multiplication sign, check marks and arrows were
# restored from mis-encoded text — confirm they match the intended glyphs.
DESCRIPTION = """
# 🎨 UltraPixel High-Resolution Image Generator

Generate ultra-high-resolution images (up to 5120×4096) with full parameter control.

**Fixed Issues:**
- ✅ CFG and timestep sliders now actually work (not hardcoded)
- ✅ Memory optimized for large resolutions
- ✅ Independent stage execution

**Pipeline:**
- **Stage C**: Text → Latent (with UltraPixel high-res guidance)
- **Stage B+A**: Latent → Final ultra-high-res image
"""
|
|
|
|
|
|
|
|
|
|
|
def save_latent_to_disk(latent_tensor, latent_id, metadata=None):
    """Write a latent tensor and its metadata to ``LATENT_DIR``.

    The tensor is moved to the CPU and saved, together with the metadata,
    as a dict in ``<LATENT_DIR>/<latent_id>.pt``.

    Args:
        latent_tensor: Tensor to store.
        latent_id: Identifier used as the file stem.
        metadata: Optional mapping stored next to the tensor; a falsy value
            is stored as an empty dict.
    """
    payload = dict(
        latent=latent_tensor.cpu(),
        metadata=metadata if metadata else {},
    )
    destination = os.path.join(LATENT_DIR, f"{latent_id}.pt")
    torch.save(payload, destination)
|
|
|
|
|
def load_latent_from_disk(latent_id):
    """Read a stored latent back from ``LATENT_DIR``.

    Args:
        latent_id: Identifier used as the file stem.

    Returns:
        ``(latent, metadata)``. ``(None, None)`` is returned when no file
        exists for ``latent_id``. When the file does not hold a dict, the
        loaded object itself is returned with empty metadata.
    """
    source = os.path.join(LATENT_DIR, f"{latent_id}.pt")
    if os.path.exists(source):
        # Tensors are mapped onto the module-level `device` while loading.
        stored = torch.load(source, map_location=device)
        if not isinstance(stored, dict):
            return stored, {}
        return stored['latent'], stored.get('metadata', {})
    return None, None
|
|
|
|
|
def cleanup_old_latents():
    """Delete ``*.pt`` files in ``LATENT_DIR`` that are older than one hour.

    Removal is best-effort: a file that disappears between listing and
    inspection, or that cannot be deleted, is skipped. Only ``OSError`` is
    suppressed, so unrelated failures are no longer hidden.
    """
    max_age_seconds = 3600

    try:
        filenames = os.listdir(LATENT_DIR)
    except FileNotFoundError:
        # Nothing to clean when the directory does not exist.
        return

    now = datetime.datetime.now()
    for filename in filenames:
        if not filename.endswith('.pt'):
            continue
        filepath = os.path.join(LATENT_DIR, filename)
        try:
            # getmtime is inside the try as well: the file may already have
            # been removed by another cleanup running at the same time.
            modified = datetime.datetime.fromtimestamp(os.path.getmtime(filepath))
            if (now - modified).total_seconds() > max_age_seconds:
                os.remove(filepath)
        except OSError:
            continue
|
|
|
|
|
|
|
|
|
|
|
def download_models():
    """Download every checkpoint the app needs into the local ``models`` directory.

    Five files come from ``stabilityai/stable-cascade``, followed by the
    UltraPixel text-to-image weights from ``roubaofeipi/UltraPixel``.
    """
    cascade_checkpoints = (
        'stage_a.safetensors',
        'previewer.safetensors',
        'effnet_encoder.safetensors',
        'stage_b_lite_bf16.safetensors',
        'stage_c_bf16.safetensors',
    )

    # (repo, file) pairs in download order.
    downloads = [("stabilityai/stable-cascade", name) for name in cascade_checkpoints]
    downloads.append(("roubaofeipi/UltraPixel", 'ultrapixel_t2i.safetensors'))

    for repo, name in downloads:
        hf_hub_download(repo_id=repo, filename=name, local_dir='models')
|
|
|
|
|
def load_models():
    """Build the Stage C and Stage B model bundles and load the UltraPixel weights.

    Populates the module-level globals ``core``, ``models`` and ``extras``
    (Stage C) and ``core_b``, ``models_b`` and ``extras_b`` (Stage B) that the
    generation functions read.
    """
    global core, core_b, models, models_b, extras, extras_b

    # --- Stage C ---
    with open('configs/training/t2i.yaml', 'r', encoding='utf-8') as f:
        config_c = yaml.safe_load(f)

    core = WurstCoreC(config_dict=config_c, device=device, training=False)
    extras = core.setup_extras_pre()
    models = core.setup_models(extras)
    models.generator.eval().requires_grad_(False)

    # --- Stage B ---
    with open('configs/inference/stage_b_1b.yaml', 'r', encoding='utf-8') as f:
        config_b = yaml.safe_load(f)

    core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
    extras_b = core_b.setup_extras_pre()
    # Stage B is set up with skip_clip=True, then rebuilt so that it shares
    # the tokenizer and text model already held by the Stage C bundle.
    models_b = core_b.setup_models(extras_b, skip_clip=True)
    models_b = WurstCoreB.Models(
        **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
    )
    models_b.generator.bfloat16().eval().requires_grad_(False)

    # --- UltraPixel weights ---
    # NOTE(review): the file carries a .safetensors extension but is read with
    # torch.load — confirm the on-disk format.
    ultrapixel_weights = torch.load('models/ultrapixel_t2i.safetensors', map_location='cpu')
    # Drop the first seven characters of every key (presumably a "module."
    # prefix — TODO confirm) before loading into train_norm.
    collect_sd = {k[7:]: v for k, v in ultrapixel_weights.items()}

    models.train_norm.load_state_dict(collect_sd)
    models.train_norm.eval()

    # The check mark was mis-encoded and split this literal across two lines.
    print("✅ All models loaded successfully")
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def generate_stage_c(
    prompt: str,
    height: int,
    width: int,
    seed: int,
    cfg: float,
    timesteps: int,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    """Stage C: generate a high-resolution latent with UltraPixel guidance.

    Args:
        prompt: User prompt; a fixed quality suffix is appended before use.
        height: Target image height.
        width: Target image width.
        seed: Seed applied to the torch, ``random`` and NumPy generators.
        cfg: Value written to the Stage C ``cfg`` sampling option.
        timesteps: Value written to the Stage C ``timesteps`` sampling option.
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        ``(latent_id, status)`` where ``latent_id`` names the latent saved to
        disk for Stage B and ``status`` is a short message for the UI.
    """
    # Seed every RNG in use.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    full_prompt = prompt + ' rich detail, 4k, high quality'

    # Latent shapes for the requested size and for the low-resolution size
    # derived from the aspect ratio.
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=1)

    # Sampling options taken from the UI controls.
    extras.sampling_configs['cfg'] = cfg
    extras.sampling_configs['shift'] = 1
    extras.sampling_configs['timesteps'] = timesteps
    extras.sampling_configs['t_start'] = 1.0
    extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)

    batch = {'captions': [full_prompt]}

    with torch.no_grad():
        models.generator.cuda()
        try:
            with torch.cuda.amp.autocast(dtype=dtype):
                sampled_c = generation_c(
                    batch, models, extras, core,
                    stage_c_latent_shape, stage_c_latent_shape_lr, device
                )
        finally:
            # Move the generator back to the CPU and release cached CUDA
            # memory even when sampling raises, so a failed request does not
            # leave the model on the GPU.
            models.generator.cpu()
            torch.cuda.empty_cache()

    # Hand the latent over to Stage B through the disk.
    import uuid
    latent_id = str(uuid.uuid4())
    metadata = {
        'prompt': full_prompt,
        'height': height,
        'width': width,
        'seed': seed
    }
    save_latent_to_disk(sampled_c, latent_id, metadata)

    del sampled_c
    torch.cuda.empty_cache()

    # The check mark was mis-encoded and split this literal across two lines.
    status = f"✅ Stage C Complete | ID: {latent_id[:8]}..."
    return latent_id, status
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def generate_stage_b(
    latent_id: str,
    cfg: float,
    timesteps: int,
    stage_a_tiled: bool,
    progress=gr.Progress(track_tqdm=True)
) -> Image.Image:
    """Stage B+A: decode a stored Stage C latent into the final image.

    Args:
        latent_id: Identifier of the latent written by Stage C.
        cfg: Value written to the Stage B ``cfg`` sampling option.
        timesteps: Value written to the Stage B ``timesteps`` sampling option.
        stage_a_tiled: Forwarded to ``decode_b`` as ``stage_a_tiled``.
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        The first image produced by ``show_images``.

    Raises:
        gr.Error: If ``latent_id`` is empty or no latent can be loaded for it.
    """
    if not latent_id:
        raise gr.Error("Invalid latent ID from Stage C")

    stage_c_latent, stored_meta = load_latent_from_disk(latent_id)
    if stage_c_latent is None:
        raise gr.Error("Could not load latent from Stage C")

    # Recover the generation settings saved alongside the latent.
    caption = stored_meta.get('prompt', '')
    target_height = stored_meta.get('height', 2048)
    target_width = stored_meta.get('width', 2048)

    _, latent_shape_b = calculate_latent_sizes(target_height, target_width, batch_size=1)

    # Sampling options taken from the UI controls.
    sampler_settings = {'cfg': cfg, 'shift': 1, 'timesteps': timesteps, 't_start': 1.0}
    for option, value in sampler_settings.items():
        extras_b.sampling_configs[option] = value

    caption_batch = {'captions': [caption]}

    # Conditional pass first, then the unconditional one.
    cond = core_b.get_conditions(
        caption_batch, models_b, extras_b, is_eval=True, is_unconditional=False
    )
    uncond = core_b.get_conditions(
        caption_batch, models_b, extras_b, is_eval=True, is_unconditional=True
    )
    # The Stage C latent is the 'effnet' condition; zeros of the same shape
    # are used for the unconditional branch.
    cond['effnet'] = stage_c_latent
    uncond['effnet'] = torch.zeros_like(stage_c_latent)

    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        decoded = decode_b(
            cond, uncond, models_b,
            latent_shape_b, extras_b, device,
            stage_a_tiled=stage_a_tiled
        )

    torch.cuda.empty_cache()
    images = show_images(decoded)

    del stage_c_latent, decoded
    torch.cuda.empty_cache()

    return images[0]
|
|
|
|
|
|
|
|
|
|
|
# Page-level CSS handed to gr.Blocks.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

# UI layout and event wiring.
# NOTE(review): the source had lost its indentation, so the widget nesting
# below is reconstructed from widget order — confirm against the running app.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(DESCRIPTION)

    # Carries the Stage C latent identifier over to the Stage B click handler.
    latent_id = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A breathtaking landscape...",
                lines=3
            )

            with gr.Row():
                height = gr.Slider(1536, 4096, value=2304, step=32, label="Height")
                width = gr.Slider(1536, 5120, value=4096, step=32, label="Width")

            seed = gr.Number(label="Seed", value=123, precision=0)

            gr.Markdown("---")
            gr.Markdown("### Stage C: Latent Generation")

            with gr.Row():
                cfg_c = gr.Slider(3, 10, value=4, step=0.1, label="CFG Scale")
                steps_c = gr.Slider(10, 50, value=20, step=1, label="Timesteps")

            # NOTE(review): the leading emoji on both buttons was mis-encoded
            # in the source and cannot be recovered exactly; the rocket is a
            # best guess — confirm.
            btn_stage_c = gr.Button("🚀 Generate Latent (Stage C)", variant="primary", size="lg")
            status_c = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("---")
            gr.Markdown("### Stage B+A: Image Decoding")

            with gr.Row():
                cfg_b = gr.Slider(1, 5, value=1.1, step=0.1, label="CFG Scale")
                steps_b = gr.Slider(5, 30, value=10, step=1, label="Timesteps")

            stage_a_tiled = gr.Checkbox(label="Use Tiled Decoding (recommended for large images)", value=False)

            btn_stage_b = gr.Button("🚀 Generate Image (Stage B+A)", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Output", type="pil")

            # Kept flush-left so the markdown is not read as an indented
            # code block.
            gr.Markdown("""
### Usage

1. Enter your prompt and configure resolution
2. Click "Generate Latent" (60-90s)
3. Click "Generate Image" (60-90s)

**Recommended Settings:**
- Stage C: CFG 4, Steps 20
- Stage B: CFG 1.1, Steps 10
- Enable tiling for resolutions >3000px

**Note:** Each stage runs independently with separate GPU allocation.
""")

    gr.Examples(
        examples=[
            "A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.",
            "A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing a vintage floral dress and standing in front of a blooming garden.",
            "A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake Louise are surrounded by snow-capped mountains and dense pine forests.",
            "A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney and warm lights glowing from the windows.",
        ],
        inputs=[prompt],
        outputs=[output_image]
    )

    # Stage C stores the new latent id in the State and reports a status line.
    btn_stage_c.click(
        fn=generate_stage_c,
        inputs=[prompt, height, width, seed, cfg_c, steps_c],
        outputs=[latent_id, status_c]
    )

    # Stage B decodes whichever latent id the State currently holds.
    btn_stage_b.click(
        fn=generate_stage_b,
        inputs=[latent_id, cfg_b, steps_b, stage_a_tiled],
        outputs=[output_image]
    )

    # Remove expired latent files whenever the page is loaded.
    demo.load(cleanup_old_latents)
|
|
|
|
|
if __name__ == "__main__":
    # Fetch the checkpoints first: load_models() reads files from `models/`.
    download_models()
    load_models()

    # Queue requests (max_size=20) and launch with show_api=False.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(show_api=False)
|
|
|