"""
Z-Image Client
==============
Client for Z-Image (Tongyi-MAI) local image generation.
Supports text-to-image and image-to-image editing.
Z-Image is a 6B-parameter model family; the Turbo variant reaches
state-of-the-art quality in only 8-9 inference steps and fits in 16GB of VRAM.
"""
import logging
import time
from typing import Optional
from PIL import Image
import torch
from .models import GenerationRequest, GenerationResult
logger = logging.getLogger(__name__)
class ZImageClient:
"""
Client for Z-Image models from Tongyi-MAI.
Supports:
- Text-to-image generation (ZImagePipeline)
- Image-to-image editing (ZImageImg2ImgPipeline)
- Multiple model variants (Turbo, Base, Edit, Omni)
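
    Example (illustrative sketch; assumes GenerationRequest accepts the
    keyword fields used below -- see src/models.py for the real signature):

        client = ZImageClient(model_variant="turbo")
        result = client.generate(GenerationRequest(prompt="a red fox", aspect_ratio="1:1"))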
"""
# Model variants
MODELS = {
# Turbo - Fast, distilled, 8-9 steps, fits 16GB VRAM
"turbo": "Tongyi-MAI/Z-Image-Turbo",
# Base - Quality-focused, more steps
"base": "Tongyi-MAI/Z-Image",
# Edit - Fine-tuned for instruction-following image editing
"edit": "Tongyi-MAI/Z-Image-Edit",
# Omni - Versatile, supports both generation and editing
"omni": "Tongyi-MAI/Z-Image-Omni-Base",
}
# Aspect ratio to dimensions mapping
# Z-Image supports 512x512 to 2048x2048
ASPECT_RATIOS = {
"1:1": (1024, 1024),
"16:9": (1344, 768),
"9:16": (768, 1344),
"21:9": (1536, 640), # Cinematic ultra-wide
"3:2": (1248, 832),
"2:3": (832, 1248),
"3:4": (896, 1152),
"4:3": (1152, 896),
"4:5": (896, 1120),
"5:4": (1120, 896),
}
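    # All presets above are multiples of 32, stay inside the supported
    # 512x512-2048x2048 range, and land at roughly one megapixel each.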
# Default settings for each model variant
MODEL_DEFAULTS = {
"turbo": {"steps": 9, "guidance": 0.0}, # Fast, no CFG needed
"base": {"steps": 50, "guidance": 4.0}, # Quality-focused
"edit": {"steps": 28, "guidance": 3.5}, # Editing
"omni": {"steps": 28, "guidance": 3.5}, # Versatile
}
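    # Turbo is guidance-distilled, which is why it runs without classifier-free
    # guidance (0.0); the non-distilled variants still rely on CFG for quality.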
def __init__(
self,
model_variant: str = "turbo",
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
enable_cpu_offload: bool = True,
):
"""
Initialize Z-Image client.
Args:
model_variant: Model variant to use:
- "turbo": Fast, 9 steps, 16GB VRAM (RECOMMENDED)
- "base": Quality-focused, 50 steps
- "edit": Instruction-following image editing
- "omni": Versatile generation + editing
device: Device to use (cuda or cpu)
dtype: Data type for model weights (bfloat16 recommended)
enable_cpu_offload: Enable CPU offload to save VRAM
"""
self.model_variant = model_variant
self.device = device
self.dtype = dtype
self.enable_cpu_offload = enable_cpu_offload
self.pipe = None
self.pipe_img2img = None
self._loaded = False
# Get default settings for this variant
defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 9, "guidance": 0.0})
self.default_steps = defaults["steps"]
self.default_guidance = defaults["guidance"]
logger.info(f"ZImageClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")
def load_model(self) -> bool:
"""Load the model into memory."""
if self._loaded:
return True
try:
# Get model ID for selected variant
model_id = self.MODELS.get(self.model_variant, self.MODELS["turbo"])
logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")
start_time = time.time()
# Import diffusers pipelines for Z-Image
# Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
from diffusers import ZImagePipeline, ZImageImg2ImgPipeline
# Load text-to-image pipeline
self.pipe = ZImagePipeline.from_pretrained(
model_id,
torch_dtype=self.dtype,
)
# Load img2img pipeline (shares components)
self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
model_id,
torch_dtype=self.dtype,
# Share components to save memory
text_encoder=self.pipe.text_encoder,
tokenizer=self.pipe.tokenizer,
vae=self.pipe.vae,
transformer=self.pipe.transformer,
scheduler=self.pipe.scheduler,
)
# Apply memory optimization
if self.enable_cpu_offload:
self.pipe.enable_model_cpu_offload()
self.pipe_img2img.enable_model_cpu_offload()
logger.info("CPU offload enabled")
else:
self.pipe.to(self.device)
self.pipe_img2img.to(self.device)
logger.info(f"Model moved to {self.device}")
# Optional: Enable flash attention if available
try:
self.pipe.transformer.set_attention_backend("flash")
self.pipe_img2img.transformer.set_attention_backend("flash")
logger.info("Flash Attention enabled")
except Exception:
logger.info("Flash Attention not available, using default SDPA")
load_time = time.time() - start_time
logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")
# Validate by running a test generation
logger.info("Validating model with test generation...")
try:
test_result = self.pipe(
prompt="A simple test image",
height=256,
width=256,
guidance_scale=0.0,
num_inference_steps=2,
generator=torch.Generator(device="cpu").manual_seed(42),
)
                if test_result.images and test_result.images[0] is not None:
logger.info("Model validation successful")
else:
logger.error("Model validation failed: no output image")
return False
except Exception as e:
logger.error(f"Model validation failed: {e}", exc_info=True)
return False
self._loaded = True
return True
except Exception as e:
logger.error(f"Failed to load Z-Image: {e}", exc_info=True)
return False
def unload_model(self):
"""Unload model from memory."""
if self.pipe is not None:
del self.pipe
self.pipe = None
if self.pipe_img2img is not None:
del self.pipe_img2img
self.pipe_img2img = None
self._loaded = False
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Z-Image unloaded")
def generate(
self,
request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None
) -> GenerationResult:
"""
Generate image using Z-Image.
Args:
request: GenerationRequest object
            num_inference_steps: Number of denoising steps (defaults to the variant preset, e.g. 9 for turbo)
            guidance_scale: Classifier-free guidance scale (defaults to the variant preset, e.g. 0.0 for turbo)
Returns:
GenerationResult object
"""
if not self._loaded:
if not self.load_model():
return GenerationResult.error_result("Failed to load Z-Image model")
# Use model defaults if not specified
if num_inference_steps is None:
num_inference_steps = self.default_steps
if guidance_scale is None:
guidance_scale = self.default_guidance
try:
start_time = time.time()
# Get dimensions from aspect ratio
width, height = self._get_dimensions(request.aspect_ratio)
logger.info(f"Generating with Z-Image {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")
# Check if we have input images (use img2img pipeline)
if request.has_input_images:
return self._generate_img2img(
request, width, height, num_inference_steps, guidance_scale, start_time
)
# Text-to-image generation
gen_kwargs = {
"prompt": request.prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": torch.Generator(device="cpu").manual_seed(42),
}
# Add negative prompt if present
if request.negative_prompt:
gen_kwargs["negative_prompt"] = request.negative_prompt
logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")
# Generate
with torch.inference_mode():
output = self.pipe(**gen_kwargs)
image = output.images[0]
generation_time = time.time() - start_time
logger.info(f"Generated in {generation_time:.2f}s: {image.size}")
return GenerationResult.success_result(
image=image,
message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
generation_time=generation_time
)
except Exception as e:
logger.error(f"Z-Image generation failed: {e}", exc_info=True)
return GenerationResult.error_result(f"Z-Image error: {str(e)}")
def _generate_img2img(
self,
request: GenerationRequest,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
start_time: float
) -> GenerationResult:
"""Generate using img2img pipeline with input images."""
try:
# Get the first valid input image
input_image = None
for img in request.input_images:
if img is not None:
input_image = img
break
if input_image is None:
return GenerationResult.error_result("No valid input image provided")
# Resize input image to target dimensions
input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)
# Build generation kwargs for img2img
gen_kwargs = {
"prompt": request.prompt,
"image": input_image,
"strength": 0.6, # How much to transform the image
"height": height,
"width": width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": torch.Generator(device="cpu").manual_seed(42),
}
# Add negative prompt if present
if request.negative_prompt:
gen_kwargs["negative_prompt"] = request.negative_prompt
logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")
# Generate
with torch.inference_mode():
output = self.pipe_img2img(**gen_kwargs)
image = output.images[0]
generation_time = time.time() - start_time
logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")
return GenerationResult.success_result(
image=image,
message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
generation_time=generation_time
)
except Exception as e:
logger.error(f"Z-Image img2img generation failed: {e}", exc_info=True)
return GenerationResult.error_result(f"Z-Image img2img error: {str(e)}")
    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to the classmethod)."""
        return self.get_dimensions(aspect_ratio)
def is_healthy(self) -> bool:
"""Check if model is loaded and ready."""
return self._loaded and self.pipe is not None
    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect ratio string."""
        # Keep only the leading "W:H" token when a label follows,
        # e.g. "16:9 (Widescreen)" -> "16:9"; unknown ratios fall back to 1024x1024.
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
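

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the app. It assumes
    # GenerationRequest can be built with the keywords below and that
    # GenerationResult exposes .success, .image, and .message; adjust to the
    # actual definitions in src/models.py if they differ. Because this module
    # uses relative imports, run it as: python -m src.zimage_client
    logging.basicConfig(level=logging.INFO)

    client = ZImageClient(model_variant="turbo")
    request = GenerationRequest(prompt="a watercolor lighthouse at dusk", aspect_ratio="16:9")
    result = client.generate(request)

    if result.success:
        result.image.save("zimage_demo.png")
        print(f"Saved zimage_demo.png ({result.message})")
    else:
        print(f"Generation failed: {result.message}")

    client.unload_model()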