""" Z-Image Client ============== Client for Z-Image (Tongyi-MAI) local image generation. Supports text-to-image and image-to-image editing. Z-Image is a 6B parameter model that achieves state-of-the-art quality with only 8-9 inference steps, fitting in 16GB VRAM. """ import logging import time from typing import Optional, List from PIL import Image import torch from .models import GenerationRequest, GenerationResult logger = logging.getLogger(__name__) class ZImageClient: """ Client for Z-Image models from Tongyi-MAI. Supports: - Text-to-image generation (ZImagePipeline) - Image-to-image editing (ZImageImg2ImgPipeline) - Multiple model variants (Turbo, Base, Edit, Omni) """ # Model variants MODELS = { # Turbo - Fast, distilled, 8-9 steps, fits 16GB VRAM "turbo": "Tongyi-MAI/Z-Image-Turbo", # Base - Quality-focused, more steps "base": "Tongyi-MAI/Z-Image", # Edit - Fine-tuned for instruction-following image editing "edit": "Tongyi-MAI/Z-Image-Edit", # Omni - Versatile, supports both generation and editing "omni": "Tongyi-MAI/Z-Image-Omni-Base", } # Aspect ratio to dimensions mapping # Z-Image supports 512x512 to 2048x2048 ASPECT_RATIOS = { "1:1": (1024, 1024), "16:9": (1344, 768), "9:16": (768, 1344), "21:9": (1536, 640), # Cinematic ultra-wide "3:2": (1248, 832), "2:3": (832, 1248), "3:4": (896, 1152), "4:3": (1152, 896), "4:5": (896, 1120), "5:4": (1120, 896), } # Default settings for each model variant MODEL_DEFAULTS = { "turbo": {"steps": 9, "guidance": 0.0}, # Fast, no CFG needed "base": {"steps": 50, "guidance": 4.0}, # Quality-focused "edit": {"steps": 28, "guidance": 3.5}, # Editing "omni": {"steps": 28, "guidance": 3.5}, # Versatile } def __init__( self, model_variant: str = "turbo", device: str = "cuda", dtype: torch.dtype = torch.bfloat16, enable_cpu_offload: bool = True, ): """ Initialize Z-Image client. 

class ZImageClient:
    """
    Client for Z-Image models from Tongyi-MAI.

    Supports:
    - Text-to-image generation (ZImagePipeline)
    - Image-to-image editing (ZImageImg2ImgPipeline)
    - Multiple model variants (Turbo, Base, Edit, Omni)
    """

    # Model variants
    MODELS = {
        # Turbo - fast, distilled, 8-9 steps, fits in 16GB VRAM
        "turbo": "Tongyi-MAI/Z-Image-Turbo",
        # Base - quality-focused, more steps
        "base": "Tongyi-MAI/Z-Image",
        # Edit - fine-tuned for instruction-following image editing
        "edit": "Tongyi-MAI/Z-Image-Edit",
        # Omni - versatile, supports both generation and editing
        "omni": "Tongyi-MAI/Z-Image-Omni-Base",
    }

    # Aspect ratio to dimensions mapping.
    # Z-Image supports resolutions from 512x512 to 2048x2048.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default settings for each model variant
    MODEL_DEFAULTS = {
        "turbo": {"steps": 9, "guidance": 0.0},  # Distilled; no CFG needed
        "base": {"steps": 50, "guidance": 4.0},  # Quality-focused
        "edit": {"steps": 28, "guidance": 3.5},  # Editing
        "omni": {"steps": 28, "guidance": 3.5},  # Versatile
    }

    def __init__(
        self,
        model_variant: str = "turbo",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize the Z-Image client.

        Args:
            model_variant: Model variant to use:
                - "turbo": fast, 9 steps, 16GB VRAM (recommended)
                - "base": quality-focused, 50 steps
                - "edit": instruction-following image editing
                - "omni": versatile generation + editing
            device: Device to use ("cuda" or "cpu")
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self.pipe_img2img = None
        self._loaded = False

        # Get default settings for this variant
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 9, "guidance": 0.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]

        logger.info(
            f"ZImageClient initialized (variant: {model_variant}, "
            f"steps: {self.default_steps}, guidance: {self.default_guidance})"
        )

    def load_model(self) -> bool:
        """Load the model into memory."""
        if self._loaded:
            return True

        try:
            # Get model ID for the selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS["turbo"])
            logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")
            start_time = time.time()

            # Import diffusers pipelines for Z-Image.
            # Requires a recent diffusers build:
            #   pip install git+https://github.com/huggingface/diffusers
            from diffusers import ZImagePipeline, ZImageImg2ImgPipeline

            # Load text-to-image pipeline
            self.pipe = ZImagePipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Load img2img pipeline, sharing components with the text-to-image
            # pipeline so the weights are not loaded twice
            self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
                text_encoder=self.pipe.text_encoder,
                tokenizer=self.pipe.tokenizer,
                vae=self.pipe.vae,
                transformer=self.pipe.transformer,
                scheduler=self.pipe.scheduler,
            )

            # Apply memory optimization
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                self.pipe_img2img.enable_model_cpu_offload()
                logger.info("CPU offload enabled")
            else:
                self.pipe.to(self.device)
                self.pipe_img2img.to(self.device)
                logger.info(f"Model moved to {self.device}")

            # Optional: enable Flash Attention if available
            try:
                self.pipe.transformer.set_attention_backend("flash")
                self.pipe_img2img.transformer.set_attention_backend("flash")
                logger.info("Flash Attention enabled")
            except Exception:
                logger.info("Flash Attention not available, using default SDPA")

            load_time = time.time() - start_time
            logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")

            # Validate by running a tiny smoke-test generation
            logger.info("Validating model with test generation...")
            try:
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=0.0,
                    num_inference_steps=2,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
                if test_result.images[0] is not None:
                    logger.info("Model validation successful")
                else:
                    logger.error("Model validation failed: no output image")
                    return False
            except Exception as e:
                logger.error(f"Model validation failed: {e}", exc_info=True)
                return False

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Z-Image: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload the model from memory and free GPU cache."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        if self.pipe_img2img is not None:
            del self.pipe_img2img
            self.pipe_img2img = None
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Z-Image unloaded")
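
    # Lifecycle sketch (illustrative addition, not part of the original API):
    # context-manager support pairs load_model()/unload_model() automatically,
    # e.g. `with ZImageClient() as client: client.generate(request)`.
    def __enter__(self) -> "ZImageClient":
        if not self.load_model():
            raise RuntimeError("Failed to load Z-Image model")
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.unload_model()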

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
    ) -> GenerationResult:
        """
        Generate an image using Z-Image.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps (9 for turbo)
            guidance_scale: Classifier-free guidance scale (0.0 for turbo)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Z-Image model")

        # Fall back to the variant's defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            logger.info(
                f"Generating with Z-Image {self.model_variant}: "
                f"steps={num_inference_steps}, guidance={guidance_scale}"
            )

            # If input images are present, route to the img2img pipeline
            if request.has_input_images:
                return self._generate_img2img(
                    request, width, height, num_inference_steps, guidance_scale, start_time
                )

            # Text-to-image generation
            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                # Fixed seed for reproducible output
                "generator": torch.Generator(device="cpu").manual_seed(42),
            }

            # Add negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)

            image = output.images[0]
            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )

        except Exception as e:
            logger.error(f"Z-Image generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image error: {str(e)}")

    def _generate_img2img(
        self,
        request: GenerationRequest,
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        start_time: float,
    ) -> GenerationResult:
        """Generate using the img2img pipeline with input images."""
        try:
            # Use the first valid input image
            input_image = None
            for img in request.input_images:
                if img is not None:
                    input_image = img
                    break

            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Resize the input image to the target dimensions
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generation kwargs for img2img
            gen_kwargs = {
                "prompt": request.prompt,
                "image": input_image,
                "strength": 0.6,  # How strongly to transform the input image
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(42),
            }

            # Add negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe_img2img(**gen_kwargs)

            image = output.images[0]
            generation_time = time.time() - start_time
            logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )

        except Exception as e:
            logger.error(f"Z-Image img2img generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image img2img error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect ratio string."""
        # Delegate to the classmethod so the lookup logic lives in one place
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if the model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect ratio string like "16:9"."""
        # Tolerate trailing annotations such as "16:9 (widescreen)"
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
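
# Minimal usage sketch. Assumptions are flagged inline: GenerationRequest is
# constructed with keyword arguments inferred from the attributes generate()
# reads above, and GenerationResult is assumed to expose the image/message
# fields passed to success_result(). Run as a module (python -m ...) since
# this file uses a relative import.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    client = ZImageClient(model_variant="turbo")

    # Hypothetical constructor call: field names mirror the attributes read
    # in generate() (prompt, aspect_ratio, negative_prompt, ...).
    request = GenerationRequest(
        prompt="A lighthouse on a cliff at sunset, dramatic clouds",
        aspect_ratio="16:9",
    )

    result = client.generate(request)
    # .image/.message are assumptions based on the success_result(...) kwargs.
    image = getattr(result, "image", None)
    if image is not None:
        image.save("zimage_sample.png")
    print(getattr(result, "message", result))

    client.unload_model()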