Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,728 Bytes
da23dfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
"""
Z-Image Client
==============
Client for Z-Image (Tongyi-MAI) local image generation.
Supports text-to-image and image-to-image editing.
Z-Image is a 6B parameter model that achieves state-of-the-art quality
with only 8-9 inference steps, fitting in 16GB VRAM.
"""
import logging
import time
from typing import Optional, List
from PIL import Image
import torch
from .models import GenerationRequest, GenerationResult
# Module-level logger named after this module, so log output can be filtered
# per module through the standard logging configuration.
logger = logging.getLogger(__name__)
class ZImageClient:
    """
    Client for Z-Image models from Tongyi-MAI.

    Supports:
    - Text-to-image generation (ZImagePipeline)
    - Image-to-image editing (ZImageImg2ImgPipeline)
    - Multiple model variants (Turbo, Base, Edit, Omni)

    The diffusers pipelines are created lazily: either explicitly through
    ``load_model()`` or implicitly on the first ``generate()`` call.
    """

    # Model variants
    MODELS = {
        # Turbo - Fast, distilled, 8-9 steps, fits 16GB VRAM
        "turbo": "Tongyi-MAI/Z-Image-Turbo",
        # Base - Quality-focused, more steps
        "base": "Tongyi-MAI/Z-Image",
        # Edit - Fine-tuned for instruction-following image editing
        "edit": "Tongyi-MAI/Z-Image-Edit",
        # Omni - Versatile, supports both generation and editing
        "omni": "Tongyi-MAI/Z-Image-Omni-Base",
    }

    # Variant used when the caller asks for one that is not in MODELS.
    DEFAULT_VARIANT = "turbo"

    # Aspect ratio to (width, height) mapping
    # Z-Image supports 512x512 to 2048x2048
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # (width, height) returned for missing or unrecognised aspect ratios.
    DEFAULT_DIMENSIONS = (1024, 1024)

    # Default settings for each model variant
    MODEL_DEFAULTS = {
        "turbo": {"steps": 9, "guidance": 0.0},   # Fast, no CFG needed
        "base": {"steps": 50, "guidance": 4.0},   # Quality-focused
        "edit": {"steps": 28, "guidance": 3.5},   # Editing
        "omni": {"steps": 28, "guidance": 3.5},   # Versatile
    }

    # Seed used for the torch generator when the caller does not supply one.
    DEFAULT_SEED = 42

    # ``strength`` passed to the img2img pipeline (how much to transform the image).
    IMG2IMG_STRENGTH = 0.6

    def __init__(
        self,
        model_variant: str = "turbo",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize Z-Image client. No model weights are loaded here.

        Args:
            model_variant: Model variant to use:
                - "turbo": Fast, 9 steps, 16GB VRAM (RECOMMENDED)
                - "base": Quality-focused, 50 steps
                - "edit": Instruction-following image editing
                - "omni": Versatile generation + editing
                An unknown value falls back to "turbo" with a warning.
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM
        """
        if model_variant not in self.MODELS:
            # load_model() would load the turbo weights for an unknown variant
            # anyway; normalising here keeps log lines and result messages
            # truthful about which model is actually in use.
            logger.warning(
                f"Unknown Z-Image variant {model_variant!r}, "
                f"falling back to {self.DEFAULT_VARIANT!r}"
            )
            model_variant = self.DEFAULT_VARIANT

        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self.pipe_img2img = None
        self._loaded = False

        # Get default settings for this variant
        defaults = self.MODEL_DEFAULTS.get(
            model_variant, self.MODEL_DEFAULTS[self.DEFAULT_VARIANT]
        )
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]
        logger.info(f"ZImageClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")

    def load_model(self) -> bool:
        """
        Load the model into memory.

        Builds the text-to-image and img2img pipelines, applies device
        placement, then runs a small test generation. On any failure the
        partially loaded pipelines are released again so that a later retry
        does not stack a second copy of the weights on top of the first.

        Returns:
            True if the model is loaded and validated, False otherwise.
        """
        if self._loaded:
            return True

        try:
            # Get model ID for selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS[self.DEFAULT_VARIANT])
            logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")
            start_time = time.time()

            # Import diffusers pipelines for Z-Image
            # Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
            from diffusers import ZImagePipeline, ZImageImg2ImgPipeline

            # Load text-to-image pipeline
            self.pipe = ZImagePipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Load img2img pipeline, reusing the components of the first
            # pipeline to save memory
            self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
                text_encoder=self.pipe.text_encoder,
                tokenizer=self.pipe.tokenizer,
                vae=self.pipe.vae,
                transformer=self.pipe.transformer,
                scheduler=self.pipe.scheduler,
            )

            self._place_pipelines()
            self._enable_flash_attention()

            load_time = time.time() - start_time
            logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")

            if not self._validate_pipeline():
                self._release_pipelines()
                return False

            self._loaded = True
            return True
        except Exception as e:
            logger.exception(f"Failed to load Z-Image: {e}")
            self._release_pipelines()
            return False

    def _place_pipelines(self) -> None:
        """Apply CPU offload, or move both pipelines to the configured device."""
        runs_on_cpu = str(self.device).startswith("cpu")
        if self.enable_cpu_offload and not runs_on_cpu:
            # Offloading needs an accelerator to offload *from*; the device is
            # passed through so e.g. "cuda:1" is honoured.
            self.pipe.enable_model_cpu_offload(device=self.device)
            self.pipe_img2img.enable_model_cpu_offload(device=self.device)
            logger.info("CPU offload enabled")
        else:
            self.pipe.to(self.device)
            self.pipe_img2img.to(self.device)
            logger.info(f"Model moved to {self.device}")

    def _enable_flash_attention(self) -> None:
        """Best-effort switch to the flash attention backend; never raises."""
        try:
            self.pipe.transformer.set_attention_backend("flash")
            self.pipe_img2img.transformer.set_attention_backend("flash")
            logger.info("Flash Attention enabled")
        except Exception:
            # Deliberately tolerant: flash attention is optional.
            logger.info("Flash Attention not available, using default SDPA")

    def _validate_pipeline(self) -> bool:
        """Run a tiny 2-step test generation; return True if it yields an image."""
        logger.info("Validating model with test generation...")
        try:
            test_result = self.pipe(
                prompt="A simple test image",
                height=256,
                width=256,
                guidance_scale=0.0,
                num_inference_steps=2,
                generator=torch.Generator(device="cpu").manual_seed(self.DEFAULT_SEED),
            )
            if test_result.images[0] is None:
                logger.error("Model validation failed: no output image")
                return False
        except Exception as e:
            logger.exception(f"Model validation failed: {e}")
            return False
        logger.info("Model validation successful")
        return True

    def _release_pipelines(self) -> None:
        """Drop both pipelines, mark the client unloaded and free cached memory."""
        import gc

        self.pipe = None
        self.pipe_img2img = None
        self._loaded = False
        # Collect first so objects kept alive only by reference cycles are
        # gone before the CUDA caching allocator is asked to release memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def unload_model(self):
        """Unload model from memory."""
        self._release_pipelines()
        logger.info("Z-Image unloaded")

    def generate(
        self,
        request: "GenerationRequest",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: Optional[int] = None,
    ) -> "GenerationResult":
        """
        Generate image using Z-Image.

        Loads the model first if necessary. When ``request.has_input_images``
        is truthy the img2img pipeline is used, otherwise text-to-image.
        Errors are not raised; they are logged and returned as an error result.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps; defaults to the
                variant default (9 for turbo)
            guidance_scale: Classifier-free guidance scale; defaults to the
                variant default (0.0 for turbo)
            seed: Seed for the torch generator; defaults to ``DEFAULT_SEED``

        Returns:
            GenerationResult object
        """
        if not self._loaded and not self.load_model():
            return GenerationResult.error_result("Failed to load Z-Image model")

        # Use model defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)
            logger.info(f"Generating with Z-Image {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")

            # Check if we have input images (use img2img pipeline)
            if request.has_input_images:
                return self._generate_img2img(
                    request, width, height, num_inference_steps, guidance_scale, start_time, seed
                )

            # Text-to-image generation
            gen_kwargs = self._build_generation_kwargs(
                request, width, height, num_inference_steps, guidance_scale, seed
            )
            logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")
            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )
        except Exception as e:
            logger.exception(f"Z-Image generation failed: {e}")
            return GenerationResult.error_result(f"Z-Image error: {e}")

    def _build_generation_kwargs(
        self,
        request: "GenerationRequest",
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        seed: Optional[int] = None,
    ) -> dict:
        """Build the keyword arguments shared by both pipelines."""
        gen_kwargs = {
            "prompt": request.prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": torch.Generator(device="cpu").manual_seed(
                self.DEFAULT_SEED if seed is None else seed
            ),
        }
        # Add negative prompt if present
        if request.negative_prompt:
            gen_kwargs["negative_prompt"] = request.negative_prompt
        return gen_kwargs

    def _generate_img2img(
        self,
        request: "GenerationRequest",
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        start_time: float,
        seed: Optional[int] = None,
    ) -> "GenerationResult":
        """
        Generate using img2img pipeline with input images.

        Only the first non-None entry of ``request.input_images`` is used; it
        is resized to ``(width, height)`` before being passed to the pipeline.
        ``start_time`` is the timestamp taken by the caller, so the reported
        generation time covers the whole request.
        """
        try:
            # Get the first valid input image
            input_image = next(
                (img for img in request.input_images if img is not None), None
            )
            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Resize input image to target dimensions
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generation kwargs for img2img
            gen_kwargs = self._build_generation_kwargs(
                request, width, height, num_inference_steps, guidance_scale, seed
            )
            gen_kwargs["image"] = input_image
            gen_kwargs["strength"] = self.IMG2IMG_STRENGTH
            logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe_img2img(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")
            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time,
            )
        except Exception as e:
            logger.exception(f"Z-Image img2img generation failed: {e}")
            return GenerationResult.error_result(f"Z-Image img2img error: {e}")

    def _get_dimensions(self, aspect_ratio: Optional[str]) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to ``get_dimensions``)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return bool(self._loaded) and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: Optional[str]) -> tuple:
        """
        Get pixel dimensions for aspect ratio.

        Only the first whitespace-separated token is looked up, so a label
        such as ``"16:9 (Widescreen)"`` resolves like ``"16:9"``.

        Args:
            aspect_ratio: Ratio string such as "16:9". ``None``, empty,
                whitespace-only or unknown values yield the default size.

        Returns:
            ``(width, height)`` tuple in pixels.
        """
        if not isinstance(aspect_ratio, str):
            return cls.DEFAULT_DIMENSIONS
        tokens = aspect_ratio.split()
        if not tokens:
            return cls.DEFAULT_DIMENSIONS
        return cls.ASPECT_RATIOS.get(tokens[0], cls.DEFAULT_DIMENSIONS)
|