Spaces:
Sleeping
Sleeping
| import random | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from diffusers import ZImagePipeline, ZImageTransformer2DModel | |
# Base model repo: the pipeline and transformer architecture come from here;
# only the transformer's weights are later overlaid by the fine-tune below.
BASE_ID: str = "Tongyi-MAI/Z-Image-Turbo"

# Hub location of the fine-tuned transformer weights (safetensors state dict).
CUSTOM_REPO: str = "MutantSparrow/Ray"
CUSTOM_FILE: str = "Z-IMAGE-TURBO/Rayzist.v1.0.safetensors"

# Sampling settings passed to the pipeline on every request.
FIXED_STEPS: int = 8
GUIDANCE: float = 1.0

# Lazily-built pipeline singleton; populated on first use by load_pipe().
pipe = None
def load_pipe():
    """Return the shared Z-Image pipeline, building it on first call.

    Loads the base transformer and pipeline from ``BASE_ID`` in bfloat16,
    moves the pipeline to CUDA, then overlays the fine-tuned transformer
    weights downloaded from ``CUSTOM_REPO``/``CUSTOM_FILE``. The finished
    pipeline is cached in the module-level ``pipe`` so later calls return
    immediately.

    The global is assigned only after the custom weights have been applied.
    A failure part-way through (for example the checkpoint download) leaves
    ``pipe`` as ``None`` so the next call retries. Without this, the app
    would silently serve the un-tuned base model.

    Returns:
        The ready-to-use ``ZImagePipeline``.
    """
    global pipe
    if pipe is not None:
        return pipe

    transformer = ZImageTransformer2DModel.from_pretrained(
        BASE_ID,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # Build into a local so a failure below never publishes a
    # half-initialised (base-weights-only) pipeline through the global.
    new_pipe = ZImagePipeline.from_pretrained(
        BASE_ID,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    ckpt_path = hf_hub_download(CUSTOM_REPO, CUSTOM_FILE)
    state = load_file(ckpt_path)
    # strict=False tolerates partial checkpoints; the returned key lists are
    # reported so a mismatch is visible in the logs.
    missing, unexpected = new_pipe.transformer.load_state_dict(state, strict=False)
    matched = len(state) - len(unexpected)
    print("Loaded custom weights.")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))
    if matched == 0:
        # Every tensor was rejected (e.g. a key-prefix mismatch). Without this
        # warning the app would quietly run the base model under the
        # fine-tune's name.
        print(
            f"WARNING: none of the {len(state)} tensors in {CUSTOM_FILE} "
            "matched the transformer; the custom fine-tune is NOT applied."
        )
    del state  # release the host-side copy of the checkpoint before serving

    new_pipe.set_progress_bar_config(disable=True)
    pipe = new_pipe
    return pipe
@spaces.GPU
def generate(prompt, height, width):
    """Render one image for *prompt* at the requested size with a fresh seed.

    ``@spaces.GPU`` attaches a GPU for the duration of the call on Hugging Face
    ZeroGPU hardware, where CUDA is otherwise unavailable to the process.
    Outside ZeroGPU the decorator is documented to have no effect.

    NOTE(review): on ZeroGPU the first call also downloads and loads the model
    (via ``load_pipe``) inside this GPU window. If that exceeds the default
    duration, consider ``@spaces.GPU(duration=...)`` or loading at import time.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        height: Output height in pixels; coerced with ``int()`` defensively.
        width: Output width in pixels; coerced with ``int()`` defensively.

    Returns:
        ``(image, seed)``: the first image of the pipeline output and the seed
        that produced it, so the result can be reproduced.
    """
    p = load_pipe()
    # Fresh random seed every run; returned so the caller can display it.
    seed = random.randint(0, 2**31 - 1)
    g = torch.Generator("cuda").manual_seed(seed)
    img = p(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=FIXED_STEPS,
        guidance_scale=GUIDANCE,
        generator=g,
    ).images[0]
    return img, seed
with gr.Blocks() as demo:
    gr.Markdown("Ray's Z-Image Turbo finetune: RAYZIST!")

    # Inputs: free-text prompt plus output-size pickers (pixels).
    prompt = gr.Textbox(label="Prompt", lines=3)
    width = gr.Dropdown(
        choices=[512, 768, 1024, 1280, 1344], value=1024, label="Width"
    )
    height = gr.Dropdown(
        choices=[512, 768, 1024, 1280, 1344], value=1024, label="Height"
    )

    # Button is deliberately placed ABOVE the output widgets.
    btn = gr.Button("GO>")
    out = gr.Image(label="Your image")
    seed_info = gr.Markdown()

    def _run(prompt, height, width):
        """Gradio callback: generate an image and format its seed as Markdown.

        The function and parameter names are kept as-is because Gradio derives
        the API endpoint and parameter names from them.
        """
        image, used_seed = generate(prompt, height, width)
        seed_markdown = f"Seed: `{used_seed}`"
        return image, seed_markdown

    btn.click(fn=_run, inputs=[prompt, height, width], outputs=[out, seed_info])
# Enable request queuing, then start the server. queue() returns the same
# Blocks instance, so the two calls can be chained.
demo.queue().launch()