# Source file: CharacterForgePro/src/qwen_image_edit_client.py
# Last deploy: "Deploy full Character Sheet Pro with HF auth" (commit da23dfe, author ghmk)
"""
Qwen-Image-Edit Client
======================
Client for Qwen-Image-Edit-2511 local image editing.
Supports multi-image editing with improved consistency.
GPU loading strategies (benchmarked on A6000 + A5000):
Pinned 2-GPU: 169.9s (4.25s/step) - 1.36x vs baseline
Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
CPU offload: 231.5s (5.79s/step) - baseline
"""
import logging
import time
import types
from typing import Optional, List

from PIL import Image
import torch

from .models import GenerationRequest, GenerationResult

# Module-level logger; callers configure handlers/levels at application startup.
logger = logging.getLogger(__name__)
class QwenImageEditClient:
    """
    Client for Qwen-Image-Edit-2511 model.

    Supports:
    - Multi-image editing (up to multiple reference images)
    - Precise text editing
    - Improved character consistency
    - LoRA integration

    GPU loading strategies (benchmarked on A6000 + A5000):
    - Pinned 2-GPU:        169.9s (4.25s/step) - 1.36x vs baseline
    - Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
    - CPU offload:         231.5s (5.79s/step) - baseline
    """

    # Model variants (variant name -> HF Hub model id).
    MODELS = {
        "full": "Qwen/Qwen-Image-Edit",  # Official Qwen model
    }

    # Legacy compatibility
    MODEL_ID = MODELS["full"]

    # Aspect ratio to dimensions mapping (target output sizes)
    ASPECT_RATIOS = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "21:9": (1680, 720),  # Cinematic ultra-wide
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
        "3:4": (1104, 1472),
        "4:3": (1472, 1104),
        "4:5": (1056, 1320),
        "5:4": (1320, 1056),
    }

    # Proven native generation resolution. Tested resolutions:
    #   1104x1472 (3:4)  -> CLEAN output (face views in v1 test)
    #   928x1664  (9:16) -> VAE tiling noise / garbage
    #   1328x1328 (1:1)  -> VAE tiling noise / garbage
    #   896x1184  (auto) -> garbage
    # Always generate at 1104x1472, then crop+resize to target.
    NATIVE_RESOLUTION = (1104, 1472)

    # VRAM thresholds for loading strategies.
    # Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
    BALANCED_VRAM_THRESHOLD_GB = 45  # Single GPU balanced (needs ~42GB + headroom)
    MAIN_GPU_MIN_VRAM_GB = 42  # Transformer + VAE minimum
    ENCODER_GPU_MIN_VRAM_GB = 17  # Text encoder minimum

    def __init__(
        self,
        model_variant: str = "full",  # Use full model (~50GB)
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        encoder_device: Optional[str] = None,
    ):
        """
        Initialize Qwen-Image-Edit client.

        Args:
            model_variant: Model variant ("full" for ~50GB)
            device: Device to use for transformer+VAE (cuda or cuda:N)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            encoder_device: Explicit device for text_encoder (e.g. "cuda:3").
                If None, auto-detected from available GPUs.
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.encoder_device = encoder_device
        # Lazily loaded in load_model(); None until then.
        self.pipe = None
        self._loaded = False
        # One of: "pinned_multi_gpu" / "balanced_single" / "cpu_offload" / "direct".
        self._loading_strategy = None
        logger.info(f"QwenImageEditClient initialized (variant: {model_variant})")

    @staticmethod
    def _get_gpu_vram_gb(device_idx: int) -> float:
        """Get total VRAM in GB for a specific GPU.

        Returns 0.0 when CUDA is unavailable or the index is out of range.
        """
        if not torch.cuda.is_available():
            return 0.0
        if device_idx >= torch.cuda.device_count():
            return 0.0
        # Decimal GB (1e9), consistent with the threshold constants above.
        return torch.cuda.get_device_properties(device_idx).total_memory / 1e9

    def _get_vram_gb(self) -> float:
        """Get available VRAM in GB for the main target device."""
        device_idx = self._parse_device_idx(self.device)
        return self._get_gpu_vram_gb(device_idx)

    @staticmethod
    def _parse_device_idx(device: str) -> int:
        """Parse CUDA device index from a device string.

        "cuda:3" -> 3; plain "cuda" (or anything unparseable) -> 0.
        """
        if device.startswith("cuda:"):
            try:
                return int(device.split(":")[1])
            except (ValueError, IndexError):
                pass
        return 0

    def _find_encoder_gpu(self, main_idx: int) -> Optional[int]:
        """Find a secondary GPU suitable for text_encoder (>= 17GB VRAM).

        Prefers GPUs with more VRAM. Skips the main GPU. Returns the GPU
        index, or None when no suitable secondary GPU exists.
        """
        if not torch.cuda.is_available():
            return None
        candidates = []
        for i in range(torch.cuda.device_count()):
            if i == main_idx:
                continue
            vram = self._get_gpu_vram_gb(i)
            if vram >= self.ENCODER_GPU_MIN_VRAM_GB:
                name = torch.cuda.get_device_name(i)
                candidates.append((i, vram, name))
        if not candidates:
            return None
        # Pick the GPU with the most VRAM
        candidates.sort(key=lambda x: x[1], reverse=True)
        best = candidates[0]
        logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)")
        return best[0]

    @staticmethod
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        NOTE: declared @staticmethod on purpose — `self` here is the *pipeline*
        instance; _load_pinned_multi_gpu binds this function onto the pipeline
        via types.MethodType.

        The original _get_qwen_prompt_embeds sends model_inputs to
        execution_device (main GPU), then calls text_encoder on a different
        GPU, causing a device mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer
        """
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]
        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)
        # NOTE(review): pixel_values/image_grid_thw are assumed present in the
        # processor output even when image is None — confirm against the
        # Qwen2VLProcessor version in use.
        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        # Drop the chat-template prefix tokens from every sequence.
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        # Right-pad all sequences (and their masks) to the batch max length.
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])
        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)
        return prompt_embeds, encoder_attention_mask

    def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool:
        """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary.

        Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline.
        """
        from diffusers import QwenImageEditPipeline
        from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
        from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
        from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
        from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

        main_dev = f"cuda:{main_idx}"
        enc_dev = f"cuda:{encoder_idx}"
        logger.info(f"Loading pinned 2-GPU: transformer+VAE → {main_dev}, text_encoder → {enc_dev}")
        # Load each component separately so we can pin devices explicitly.
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")
        tokenizer = Qwen2Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer")
        processor = Qwen2VLProcessor.from_pretrained(
            model_id, subfolder="processor")
        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=self.dtype,
        ).to(enc_dev)
        logger.info(f" text_encoder loaded on {enc_dev}")
        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=self.dtype,
        ).to(main_dev)
        logger.info(f" transformer loaded on {main_dev}")
        vae = AutoencoderKLQwenImage.from_pretrained(
            model_id, subfolder="vae", torch_dtype=self.dtype,
        ).to(main_dev)
        vae.enable_tiling()
        logger.info(f" VAE loaded on {main_dev}")
        self.pipe = QwenImageEditPipeline(
            scheduler=scheduler, vae=vae, text_encoder=text_encoder,
            tokenizer=tokenizer, processor=processor, transformer=transformer,
        )
        # Fix 1: Override _execution_device to force main GPU.
        # Without this, pipeline returns text_encoder's device, causing VAE
        # to receive tensors on the wrong GPU.
        # NOTE(review): this sets a property on the *class*, so it affects
        # every QwenImageEditPipeline instance in the process — acceptable
        # here because this client owns the only pipeline, but worth
        # confirming if multiple pipelines are ever created.
        main_device = torch.device(main_dev)
        QwenImageEditPipeline._execution_device = property(lambda self: main_device)
        # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device
        self.pipe._get_qwen_prompt_embeds = types.MethodType(
            self._patched_get_qwen_prompt_embeds, self.pipe)
        self._loading_strategy = "pinned_multi_gpu"
        logger.info("Pinned 2-GPU pipeline ready")
        return True

    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)

        Returns:
            True on success (or if already loaded), False on any failure.
        """
        if self._loaded:
            return True
        try:
            from diffusers import QwenImageEditPipeline

            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)
            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")
            start_time = time.time()
            loaded = False
            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                if self.encoder_device:
                    # Honor the explicit encoder device, but only if it is big enough.
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None
                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)
                if encoder_idx is not None:
                    self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)
                    loaded = True
            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                # Leave ~4GB headroom for activations.
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True
            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True
            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")
            self.pipe.set_progress_bar_config(disable=None)
            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")
            self._loaded = True
            return True
        except Exception as e:
            # Boundary handler: log with traceback and report failure to caller.
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload model from memory and release cached CUDA allocations."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Qwen-Image-Edit-2511 unloaded")

    def generate(
        self,
        request: "GenerationRequest",
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0
    ) -> "GenerationResult":
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        Always generates at NATIVE_RESOLUTION, then center-crops and resizes
        to the requested aspect ratio (other native sizes produce VAE tiling
        artifacts — see NATIVE_RESOLUTION notes).

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control

        Returns:
            GenerationResult object (error result on any failure).
        """
        # Lazy-load on first use.
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")
        try:
            start_time = time.time()
            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)
            # Build input images list
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]
            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION
            gen_kwargs = {
                "prompt": request.prompt,
                # Pipeline requires a non-empty negative prompt string.
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                # Fixed seed for reproducible output across calls.
                "generator": torch.manual_seed(42),
            }
            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]
            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")
            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
            image = output.images[0]
            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")
            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")
            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )
        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")

    @staticmethod
    def _crop_and_resize(image: "Image.Image", target_w: int, target_h: int) -> "Image.Image":
        """Crop image to target aspect ratio, then resize to target dimensions.

        Centers the crop on the image so equal amounts are trimmed from
        each side. Uses LANCZOS for high-quality downscaling.
        """
        src_w, src_h = image.size
        target_ratio = target_w / target_h
        src_ratio = src_w / src_h
        if abs(target_ratio - src_ratio) < 0.01:
            # Already the right aspect ratio, just resize
            return image.resize((target_w, target_h), Image.LANCZOS)
        if target_ratio < src_ratio:
            # Target is taller/narrower than source → crop sides
            crop_w = int(src_h * target_ratio)
            offset = (src_w - crop_w) // 2
            image = image.crop((offset, 0, offset + crop_w, src_h))
        else:
            # Target is wider than source → crop top/bottom
            crop_h = int(src_w / target_ratio)
            offset = (src_h - crop_h) // 2
            image = image.crop((0, offset, src_w, offset + crop_h))
        return image.resize((target_w, target_h), Image.LANCZOS)

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio.

        Delegates to the classmethod get_dimensions (previously a verbatim
        duplicate of it).
        """
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for an aspect-ratio label.

        Accepts labels like "16:9" or "16:9 (widescreen)" — anything after
        the first space is ignored. Unknown ratios fall back to (1024, 1024).
        """
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))