# Spaces: Running on Zero (Hugging Face Spaces ZeroGPU banner)
import spaces  # must be imported before any CUDA initialization (ZeroGPU requirement)

import glob
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
# Default negative prompts offered in the UI (Chinese and English variants).
NEGATIVE_PROMPT_CN = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
NEGATIVE_PROMPT_EN = "Yellowed, green-tinted, blurry, low-resolution, low-quality image, distorted limbs, eerie appearance, ugly, AI-looking, noise, grid-like artifacts, JPEG compression artifacts, abnormal limbs, watermark, garbled text, meaningless characters"

# Root directory for downloaded model weights; override with ZIMAGE_MODELS_DIR.
MODELS_DIR = Path(os.getenv("ZIMAGE_MODELS_DIR", "./models"))
| # ============================================================================= | |
| # Model Download Functions | |
| # ============================================================================= | |
def download_hf_models(output_dir: Path) -> dict:
    """Download the model repos required by the demo from Hugging Face.

    Downloads (skipping any repo that already has *.safetensors locally):
      - DiffSynth-Studio/General-Image-Encoders
      - Tongyi-MAI/Z-Image-Turbo (text encoder, VAE, tokenizer only)
      - Tongyi-MAI/Z-Image (transformer only)
      - DiffSynth-Studio/Z-Image-i2L

    Args:
        output_dir: Directory under which each repo is mirrored
            (one subdirectory per repo_id).

    Returns:
        dict mapping repo_id -> local Path of the downloaded snapshot.

    Raises:
        Exception: re-raises any snapshot_download failure after logging it.
    """
    from huggingface_hub import snapshot_download

    output_dir.mkdir(parents=True, exist_ok=True)
    models = [
        {
            "repo_id": "DiffSynth-Studio/General-Image-Encoders",
            "description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
            "allow_patterns": None,  # full repo
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image-Turbo",
            "description": "Z-Image Turbo (text encoder, VAE, tokenizer)",
            "allow_patterns": [
                "text_encoder/*.safetensors",
                "vae/*.safetensors",
                "tokenizer/*",
            ],
        },
        {
            "repo_id": "Tongyi-MAI/Z-Image",
            "description": "Z-Image base model (transformer)",
            "allow_patterns": ["transformer/*.safetensors"],
        },
        {
            "repo_id": "DiffSynth-Studio/Z-Image-i2L",
            "description": "Z-Image-i2L (Image to LoRA model)",
            "allow_patterns": ["*.safetensors"],
        },
    ]
    downloaded_paths = {}
    for model in models:
        repo_id = model["repo_id"]
        local_dir = output_dir / repo_id
        # Cheap resume check: skip repos that already contain weight files.
        if local_dir.exists() and any(local_dir.rglob("*.safetensors")):
            print(f" ✓ {repo_id} (already downloaded)")
            downloaded_paths[repo_id] = local_dir
            continue
        print(f" 📥 Downloading {repo_id}...")
        print(f" {model['description']}")
        try:
            # NOTE: `local_dir_use_symlinks` and `resume_download` are
            # deprecated in huggingface_hub (downloads resume by default and
            # local_dir no longer uses symlinks), so they are not passed.
            result_path = snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                allow_patterns=model["allow_patterns"],
            )
            downloaded_paths[repo_id] = Path(result_path)
            print(f" ✓ {repo_id}")
        except Exception as e:
            print(f" ❌ Error downloading {repo_id}: {e}")
            raise
    return downloaded_paths
def get_model_files(base_path: Path, pattern: str) -> list:
    """Return the sorted list of file paths under *base_path* matching *pattern*."""
    return sorted(glob.glob(str(base_path / pattern)))
def install_diffsynth_studio():
    """Ensure DiffSynth-Studio is importable, cloning and pip-installing if needed.

    Returns:
        Tuple of (success: bool, message: str); message is user-facing status text.
    """
    # Fast path: already installed.
    try:
        from diffsynth.pipelines.z_image import ZImagePipeline  # noqa: F401
        return True, "✅ DiffSynth-Studio is already installed."
    except ImportError:
        pass

    # Fail early with a clear message if git is unavailable on this host.
    if shutil.which("git") is None:
        error_msg = "❌ Installation failed: `git` is not available on PATH."
        print(error_msg)
        return False, error_msg

    repo_dir = Path(__file__).parent / "DiffSynth-Studio"
    try:
        if not repo_dir.exists():
            print("📥 Cloning DiffSynth-Studio repository...")
            subprocess.run(
                ["git", "clone", "https://github.com/modelscope/DiffSynth-Studio.git", str(repo_dir)],
                capture_output=True,
                text=True,
                check=True
            )
            print("✅ Repository cloned successfully.")
        else:
            print("📁 DiffSynth-Studio directory already exists, pulling latest...")
            # Best-effort update: a failed pull must not abort installation,
            # but it should be visible in the logs rather than silently ignored.
            pull = subprocess.run(
                ["git", "-C", str(repo_dir), "pull"],
                capture_output=True,
                text=True
            )
            if pull.returncode != 0:
                print(f"⚠️ git pull failed (continuing with existing checkout): {pull.stderr}")
        print("📦 Installing DiffSynth-Studio...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", str(repo_dir)],
            capture_output=True,
            text=True,
            check=True
        )
        print("✅ DiffSynth-Studio installed successfully.")
        sys.path.insert(0, str(repo_dir))
        # Verify the editable install is actually importable before reporting success.
        from diffsynth.pipelines.z_image import ZImagePipeline  # noqa: F401
        return True, "✅ DiffSynth-Studio installed successfully!"
    except subprocess.CalledProcessError as e:
        error_msg = f"❌ Installation failed: {e.stderr}"
        print(error_msg)
        return False, error_msg
    except Exception as e:
        error_msg = f"❌ Error during installation: {str(e)}"
        print(error_msg)
        return False, error_msg
| # ============================================================================= | |
| # Pipeline Initialization | |
| # ============================================================================= | |
| print("=" * 60) | |
| print(" Z-Image-i2L Gradio Demo - Initializing") | |
| print("=" * 60) | |
| print() | |
| # Step 1: Install DiffSynth-Studio | |
| print("🔍 Step 1: Checking DiffSynth-Studio installation...") | |
| success, message = install_diffsynth_studio() | |
| print(message) | |
| if not success: | |
| raise RuntimeError("Failed to install DiffSynth-Studio. Cannot continue.") | |
| # Step 2: Download HuggingFace models | |
| print() | |
| print("🔍 Step 2: Downloading models from HuggingFace...") | |
| print(f" Models directory: {MODELS_DIR.absolute()}") | |
| downloaded_paths = download_hf_models(MODELS_DIR) | |
# Import required modules — deferred until after install_diffsynth_studio()
# has made the diffsynth package importable.
from diffsynth.pipelines.z_image import (
    ZImagePipeline, ModelConfig,
    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
)
from safetensors.torch import save_file, load_file
# Step 3: Configure VRAM settings
print()
print("⚙️ Step 3: Configuring VRAM settings...")
# Everything stays on CUDA in bfloat16 (no CPU offload); unpacked into the
# transformer's ModelConfig in Step 5 below.
vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cuda",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
# Step 4: Resolve local model paths (all under MODELS_DIR, mirroring repo_ids).
print()
print("📂 Step 4: Resolving model paths...")
# Z-Image transformer (base model weights)
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")
# Z-Image-Turbo supplies the text encoder, VAE, and tokenizer
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")  # list with 0 or 1 entries
tokenizer_path = zimage_turbo_path / "tokenizer"
# General Image Encoders (SigLIP2 + DINOv3), used by the image-to-LoRA encoder
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")
# Z-Image-i2L from HuggingFace (the image-to-LoRA model itself)
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")
print(f" Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f" Text encoder: {len(text_encoder_files)} file(s)")
print(f" VAE: {len(vae_file)} file(s)")
print(f" Tokenizer: {tokenizer_path}")
print(f" SigLIP2: {len(siglip_file)} file(s)")
print(f" DINOv3: {len(dino_file)} file(s)")
print(f" Z-Image-i2L: {len(zimage_i2l_file)} file(s)")
# Validate files — fail fast with one clear message listing everything missing.
missing = []
if not zimage_transformer_files: missing.append("Z-Image transformer")
if not text_encoder_files: missing.append("Text encoder")
if not vae_file: missing.append("VAE")
if not tokenizer_path.exists(): missing.append("Tokenizer")
if not siglip_file: missing.append("SigLIP2")
if not dino_file: missing.append("DINOv3")
if not zimage_i2l_file: missing.append("Z-Image-i2L")
if missing:
    raise FileNotFoundError(f"Missing model files: {', '.join(missing)}")
# Step 5: Load pipeline
print()
print("🚀 Step 5: Loading Z-Image pipeline...")
print(" All models loaded from HuggingFace local paths")
model_configs = [
    # All models from HuggingFace - use path= for local files.
    # Only the transformer receives the explicit bf16/CUDA placement config.
    ModelConfig(path=zimage_transformer_files, **vram_config),
    ModelConfig(path=text_encoder_files),
    ModelConfig(path=vae_file),
    ModelConfig(path=siglip_file),
    ModelConfig(path=dino_file),
    ModelConfig(path=zimage_i2l_file),
]
# Global pipeline instance shared by the Gradio handlers below.
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)
print()
print("✅ Pipeline loaded successfully!")
print("=" * 60)
print()
| # ============================================================================= | |
| # Gradio Functions | |
| # ============================================================================= | |
@spaces.GPU  # required on ZeroGPU Spaces: CUDA is only available inside decorated calls
def image_to_lora(images, progress=gr.Progress()):
    """Encode the uploaded style images into a LoRA saved as a .safetensors file.

    Args:
        images: Gallery value — a list whose items may be file paths,
            (filepath, caption) tuples, or numpy arrays, depending on the
            Gradio version.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (path to the saved LoRA file or None, status message).
    """
    if images is None or len(images) == 0:
        return None, "❌ Please upload at least one image!"
    try:
        progress(0.1, desc="Processing images...")
        pil_images = []
        for img in images:
            # Normalize every gallery item shape to an RGB PIL image.
            if isinstance(img, str):
                pil_images.append(Image.open(img).convert("RGB"))
            elif isinstance(img, tuple):
                # (filepath, caption) tuple — open the file part.
                pil_images.append(Image.open(img[0]).convert("RGB"))
            else:
                pil_images.append(Image.fromarray(img).convert("RGB"))
        with torch.no_grad():  # inference only — no autograd graph needed
            progress(0.3, desc="Encoding images to LoRA...")
            embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=pil_images)
            progress(0.7, desc="Decoding LoRA weights...")
            lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
        progress(0.9, desc="Saving LoRA file...")
        # Saved to a fresh temp dir so concurrent requests never collide.
        temp_dir = tempfile.mkdtemp()
        lora_path = os.path.join(temp_dir, "generated_lora.safetensors")
        save_file(lora, lora_path)
        progress(1.0, desc="Done!")
        return lora_path, f"✅ LoRA generated successfully from {len(pil_images)} image(s)!"
    except Exception as e:
        # Surface the error in the UI status box instead of crashing the app.
        return None, f"❌ Error generating LoRA: {str(e)}"
@spaces.GPU  # required on ZeroGPU Spaces: CUDA is only available inside decorated calls
def generate_image(
    lora_file,
    prompt,
    negative_prompt,
    seed,
    cfg_scale,
    sigma_shift,
    num_steps,
    progress=gr.Progress()
):
    """Generate an image with the pipeline, applying the given LoRA.

    Args:
        lora_file: Path to a .safetensors LoRA file (from Step 1 or uploaded).
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: RNG seed (coerced to int).
        cfg_scale: Classifier-free guidance scale.
        sigma_shift: Sampler sigma-shift parameter.
        num_steps: Number of inference steps (coerced to int).
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (generated image or None, status message).
    """
    if lora_file is None:
        return None, "❌ Please generate or upload a LoRA file first!"
    try:
        progress(0.1, desc="Loading LoRA...")
        lora = load_file(lora_file)
        # Move LoRA tensors to CUDA with the pipeline's compute dtype.
        lora = {k: v.to(device="cuda", dtype=torch.bfloat16) for k, v in lora.items()}
        progress(0.3, desc="Generating image...")
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=int(seed),
            cfg_scale=cfg_scale,
            num_inference_steps=int(num_steps),
            positive_only_lora=lora,  # LoRA applied to the positive branch only
            sigma_shift=sigma_shift
        )
        progress(1.0, desc="Done!")
        return image, "✅ Image generated successfully!"
    except Exception as e:
        # Surface the error in the UI status box instead of crashing the app.
        return None, f"❌ Error generating image: {str(e)}"
| def create_demo(): | |
| """Create the Gradio interface.""" | |
| with gr.Blocks( | |
| title="Z-Image-i2L Demo", | |
| theme=gr.themes.Soft(), | |
| css=".gradio-container { max-width: 1200px !important; margin: 0 auto}" | |
| ) as demo: | |
| gr.Markdown(""" | |
| # 🎨 Z-Image-i2L: Image to LoRA Demo | |
| > 💡 **Tip**: For best results, use 4-6 images with a consistent artistic style. | |
| """) | |
| with gr.Tabs(): | |
| with gr.TabItem("📸 Step 1: Image to LoRA"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_gallery = gr.Gallery( | |
| label="Upload Style Images (1-6 images)", | |
| file_types=["image"], | |
| columns=3, | |
| height=300, | |
| interactive=True | |
| ) | |
| gr.Markdown(""" | |
| **Guidelines:** | |
| - Upload 1-6 images with a consistent style | |
| - Higher quality images produce better results | |
| - Mix of subjects helps generalization | |
| """) | |
| generate_lora_btn = gr.Button("🎯 Generate LoRA", variant="primary") | |
| with gr.Column(scale=1): | |
| lora_output = gr.File( | |
| label="Generated LoRA File", | |
| file_types=[".safetensors"], | |
| interactive=False | |
| ) | |
| lora_status = gr.Textbox( | |
| label="Status", | |
| interactive=False, | |
| lines=2 | |
| ) | |
| with gr.TabItem("🖼️ Step 2: Generate Images"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| lora_input = gr.File( | |
| label="LoRA File (from Step 1 or upload)", | |
| file_types=[".safetensors"] | |
| ) | |
| prompt = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Describe what you want to generate...", | |
| value="a cat", | |
| lines=2 | |
| ) | |
| with gr.Accordion("Negative Prompt", open=False): | |
| negative_prompt = gr.Textbox( | |
| label="Negative Prompt", | |
| value=NEGATIVE_PROMPT_CN, | |
| lines=3 | |
| ) | |
| with gr.Row(): | |
| use_cn_neg = gr.Button("Use Chinese", size="sm") | |
| use_en_neg = gr.Button("Use English", size="sm") | |
| with gr.Accordion("Advanced Settings", open=False): | |
| seed = gr.Number(label="Seed", value=0, precision=0) | |
| cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=10, value=4, step=0.5) | |
| sigma_shift = gr.Slider(label="Sigma Shift", minimum=1, maximum=15, value=8, step=1) | |
| num_steps = gr.Slider(label="Steps", minimum=20, maximum=100, value=50, step=5) | |
| generate_btn = gr.Button("✨ Generate Image", variant="primary") | |
| with gr.Column(scale=1): | |
| output_image = gr.Image(label="Generated Image", type="pil", height=512) | |
| gen_status = gr.Textbox(label="Status", interactive=False, lines=2) | |
| gr.Markdown(""" | |
| --- | |
| **Resources:** [Z-Image-i2L (HuggingFace)](https://huggingface.co/DiffSynth-Studio/Z-Image-i2L) | | |
| [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) | | |
| **Settings:** CFG=4, Sigma Shift=8, Steps=50 | |
| """) | |
| # Event handlers | |
| generate_lora_btn.click( | |
| fn=image_to_lora, | |
| inputs=[input_gallery], | |
| outputs=[lora_output, lora_status] | |
| ) | |
| lora_output.change(fn=lambda x: x, inputs=[lora_output], outputs=[lora_input]) | |
| generate_btn.click( | |
| fn=generate_image, | |
| inputs=[lora_input, prompt, negative_prompt, seed, cfg_scale, sigma_shift, num_steps], | |
| outputs=[output_image, gen_status] | |
| ) | |
| use_cn_neg.click(fn=lambda: NEGATIVE_PROMPT_CN, outputs=[negative_prompt]) | |
| use_en_neg.click(fn=lambda: NEGATIVE_PROMPT_EN, outputs=[negative_prompt]) | |
| return demo | |
| if __name__ == "__main__": | |
| print("Starting Gradio server...") | |
| demo = create_demo() | |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=False) |