| """ | |
| Z-Image Client | |
| ============== | |
| Client for Z-Image (Tongyi-MAI) local image generation. | |
| Supports text-to-image and image-to-image editing. | |
| Z-Image is a 6B parameter model that achieves state-of-the-art quality | |
| with only 8-9 inference steps, fitting in 16GB VRAM. | |
| """ | |
| import logging | |
| import time | |
| from typing import Optional, List | |
| from PIL import Image | |
| import torch | |
| from .models import GenerationRequest, GenerationResult | |
| logger = logging.getLogger(__name__) | |


class ZImageClient:
    """
    Client for Z-Image models from Tongyi-MAI.

    Supports:
    - Text-to-image generation (ZImagePipeline)
    - Image-to-image editing (ZImageImg2ImgPipeline)
    - Multiple model variants (Turbo, Base, Edit, Omni)
    """

    # Model variants
    MODELS = {
        # Turbo - fast, distilled, 8-9 steps, fits in 16GB VRAM
        "turbo": "Tongyi-MAI/Z-Image-Turbo",
        # Base - quality-focused, more steps
        "base": "Tongyi-MAI/Z-Image",
        # Edit - fine-tuned for instruction-following image editing
        "edit": "Tongyi-MAI/Z-Image-Edit",
        # Omni - versatile, supports both generation and editing
        "omni": "Tongyi-MAI/Z-Image-Omni-Base",
    }

    # Aspect ratio to dimensions mapping.
    # Z-Image supports 512x512 up to 2048x2048.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }
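    # Every preset above sits near one megapixel with both sides a multiple
    # of 32; that is an observation about this table, not a documented
    # Z-Image requirement.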

    # Default settings for each model variant
    MODEL_DEFAULTS = {
        "turbo": {"steps": 9, "guidance": 0.0},   # Fast, no CFG needed
        "base": {"steps": 50, "guidance": 4.0},   # Quality-focused
        "edit": {"steps": 28, "guidance": 3.5},   # Editing
        "omni": {"steps": 28, "guidance": 3.5},   # Versatile
    }
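    # Turbo is a distilled model, which is why it runs without
    # classifier-free guidance (0.0); the other variants use standard CFG.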

    def __init__(
        self,
        model_variant: str = "turbo",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize the Z-Image client.

        Args:
            model_variant: Model variant to use:
                - "turbo": fast, 9 steps, 16GB VRAM (recommended)
                - "base": quality-focused, 50 steps
                - "edit": instruction-following image editing
                - "omni": versatile generation + editing
            device: Device to use ("cuda" or "cpu")
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload

        self.pipe = None
        self.pipe_img2img = None
        self._loaded = False

        # Get default settings for this variant
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 9, "guidance": 0.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]

        logger.info(
            f"ZImageClient initialized (variant: {model_variant}, "
            f"steps: {self.default_steps}, guidance: {self.default_guidance})"
        )

    def load_model(self) -> bool:
        """Load the model into memory."""
        if self._loaded:
            return True

        try:
            # Get the model ID for the selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS["turbo"])
            logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")
            start_time = time.time()

            # The Z-Image pipelines require a recent diffusers build:
            #   pip install git+https://github.com/huggingface/diffusers
            from diffusers import ZImagePipeline, ZImageImg2ImgPipeline

            # Load the text-to-image pipeline
            self.pipe = ZImagePipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Load the img2img pipeline, sharing components with the
            # text-to-image pipeline to save memory
            self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
                text_encoder=self.pipe.text_encoder,
                tokenizer=self.pipe.tokenizer,
                vae=self.pipe.vae,
                transformer=self.pipe.transformer,
                scheduler=self.pipe.scheduler,
            )

            # Apply memory optimization
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                self.pipe_img2img.enable_model_cpu_offload()
                logger.info("CPU offload enabled")
            else:
                self.pipe.to(self.device)
                self.pipe_img2img.to(self.device)
                logger.info(f"Model moved to {self.device}")

            # Optional: enable Flash Attention if available
            try:
                self.pipe.transformer.set_attention_backend("flash")
                self.pipe_img2img.transformer.set_attention_backend("flash")
                logger.info("Flash Attention enabled")
            except Exception:
                logger.info("Flash Attention not available, using default SDPA")

            load_time = time.time() - start_time
            logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")

            # Validate by running a quick low-resolution test generation
            logger.info("Validating model with test generation...")
            try:
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=0.0,
                    num_inference_steps=2,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
                if test_result.images[0] is not None:
                    logger.info("Model validation successful")
                else:
                    logger.error("Model validation failed: no output image")
                    return False
            except Exception as e:
                logger.error(f"Model validation failed: {e}", exc_info=True)
                return False

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Z-Image: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload model from memory."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        if self.pipe_img2img is not None:
            del self.pipe_img2img
            self.pipe_img2img = None
        self._loaded = False
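        # The two pipelines share their underlying modules, so VRAM is only
        # reclaimed once both references above have been dropped.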
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Z-Image unloaded")

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
    ) -> GenerationResult:
        """
        Generate an image using Z-Image.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps (9 for turbo)
            guidance_scale: Classifier-free guidance scale (0.0 for turbo)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Z-Image model")

        # Use model defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from the aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            logger.info(
                f"Generating with Z-Image {self.model_variant}: "
                f"steps={num_inference_steps}, guidance={guidance_scale}"
            )

            # If input images are present, route to the img2img pipeline
            if request.has_input_images:
                return self._generate_img2img(
                    request, width, height, num_inference_steps, guidance_scale, start_time
                )

            # Text-to-image generation
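            # NOTE: the seed below is fixed at 42, so identical requests
            # produce identical images; expose it as a parameter if variety
            # across calls is needed.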
            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(42),
            }

            # Add the negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )

        except Exception as e:
            logger.error(f"Z-Image generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image error: {str(e)}")

    def _generate_img2img(
        self,
        request: GenerationRequest,
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        start_time: float,
    ) -> GenerationResult:
        """Generate using the img2img pipeline with input images."""
        try:
            # Use the first valid input image
            input_image = None
            for img in request.input_images:
                if img is not None:
                    input_image = img
                    break

            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Resize the input image to the target dimensions
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generation kwargs for img2img (same fixed seed as generate())
            gen_kwargs = {
                "prompt": request.prompt,
                "image": input_image,
                "strength": 0.6,  # How much to transform the image (0 = none, 1 = full)
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(42),
            }

            # Add the negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe_img2img(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )

        except Exception as e:
            logger.error(f"Z-Image img2img generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image img2img error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect ratio (instance wrapper around get_dimensions)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if the model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect ratio string, defaulting to 1024x1024."""
        # Accept labels like "16:9 (Widescreen)" by taking the first token
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
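

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API). It assumes
# GenerationRequest accepts `prompt` and `aspect_ratio` as constructor
# arguments and that GenerationResult exposes `success` and `image`
# attributes; the actual definitions live in .models and may differ.
#
#   from .models import GenerationRequest
#
#   client = ZImageClient(model_variant="turbo")  # 9 steps, guidance 0.0
#   request = GenerationRequest(
#       prompt="a watercolor lighthouse at dawn",
#       aspect_ratio="16:9",
#   )
#   result = client.generate(request)
#   if result.success:
#       result.image.save("lighthouse.png")
#   client.unload_model()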