"""
Qwen-Image-Edit Client
======================

Client for Qwen-Image-Edit-2511 local image editing.
Accepts multiple reference images per request, but currently edits using
only the first one (see ``QwenImageEditClient.generate``).

GPU loading strategies (benchmarked on A6000 + A5000):
    Pinned 2-GPU:         169.9s (4.25s/step) - 1.36x vs baseline
    Balanced single-GPU:  184.4s (4.61s/step) - 1.25x vs baseline
    CPU offload:          231.5s (5.79s/step) - baseline
"""

import gc
import logging
import time
import types
from typing import Optional, List

from PIL import Image
import torch

from .models import GenerationRequest, GenerationResult

logger = logging.getLogger(__name__)


class QwenImageEditClient:
    """
    Client for Qwen-Image-Edit-2511 model.

    Supports:
    - Multi-image editing (up to multiple reference images)
    - Precise text editing
    - Improved character consistency
    - LoRA integration

    NOTE(review): ``generate`` forwards only the first input image to the
    pipeline, and ``MODELS["full"]`` points at ``Qwen/Qwen-Image-Edit``
    although messages mention "2511" - confirm the intended checkpoint.
    """

    # Model variants
    MODELS = {
        "full": "Qwen/Qwen-Image-Edit",  # Official Qwen model
    }

    # Legacy compatibility
    MODEL_ID = MODELS["full"]

    # Aspect ratio to dimensions mapping (target output sizes)
    ASPECT_RATIOS = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "21:9": (1680, 720),  # Cinematic ultra-wide
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
        "3:4": (1104, 1472),
        "4:3": (1472, 1104),
        "4:5": (1056, 1320),
        "5:4": (1320, 1056),
    }

    # Size returned for unknown / missing aspect ratios
    DEFAULT_DIMENSIONS = (1024, 1024)

    # Proven native generation resolution. Tested resolutions:
    #   1104x1472 (3:4)  → CLEAN output (face views in v1 test)
    #   928x1664 (9:16)  → VAE tiling noise / garbage
    #   1328x1328 (1:1)  → VAE tiling noise / garbage
    #   896x1184 (auto)  → garbage
    # Always generate at 1104x1472, then crop+resize to target.
    NATIVE_RESOLUTION = (1104, 1472)

    # VRAM thresholds for loading strategies
    # Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
    BALANCED_VRAM_THRESHOLD_GB = 45  # Single GPU balanced (needs ~42GB + headroom)
    MAIN_GPU_MIN_VRAM_GB = 42  # Transformer + VAE minimum
    ENCODER_GPU_MIN_VRAM_GB = 17  # Text encoder minimum

    def __init__(
        self,
        model_variant: str = "full",  # Use full model (~50GB)
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        encoder_device: Optional[str] = None,
    ):
        """
        Initialize Qwen-Image-Edit client.

        No model weights are loaded here; loading happens in ``load_model``
        (called lazily by ``generate``).

        Args:
            model_variant: Model variant ("full" for ~50GB)
            device: Device to use for transformer+VAE (cuda or cuda:N)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            encoder_device: Explicit device for text_encoder (e.g. "cuda:3").
                           If None, auto-detected from available GPUs.
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.encoder_device = encoder_device
        self.pipe = None
        self._loaded = False
        self._loading_strategy = None

        logger.info(f"QwenImageEditClient initialized (variant: {model_variant})")

    @staticmethod
    def _get_gpu_vram_gb(device_idx: int) -> float:
        """Return the total VRAM (in units of 1e9 bytes) of one GPU.

        Returns 0.0 when CUDA is unavailable or the index does not name an
        existing device (including negative indices).
        """
        if not torch.cuda.is_available():
            return 0.0
        if not 0 <= device_idx < torch.cuda.device_count():
            return 0.0
        return torch.cuda.get_device_properties(device_idx).total_memory / 1e9

    def _get_vram_gb(self) -> float:
        """Return the total (not free) VRAM in GB of the main target device."""
        device_idx = self._parse_device_idx(self.device)
        return self._get_gpu_vram_gb(device_idx)

    @staticmethod
    def _parse_device_idx(device: str) -> int:
        """Parse the CUDA device index from a device string.

        "cuda:N" yields N; anything else (plain "cuda", malformed suffix)
        falls back to index 0.
        """
        if device.startswith("cuda:"):
            try:
                return int(device.split(":")[1])
            except (ValueError, IndexError):
                pass
        return 0

    def _find_encoder_gpu(self, main_idx: int) -> Optional[int]:
        """Find a secondary GPU suitable for text_encoder (>= 17GB VRAM).

        Prefers GPUs with more VRAM. Skips the main GPU.

        Returns:
            The CUDA index of the chosen GPU, or None if no other GPU meets
            ``ENCODER_GPU_MIN_VRAM_GB``.
        """
        if not torch.cuda.is_available():
            return None

        candidates = []
        for i in range(torch.cuda.device_count()):
            if i == main_idx:
                continue
            vram = self._get_gpu_vram_gb(i)
            if vram >= self.ENCODER_GPU_MIN_VRAM_GB:
                name = torch.cuda.get_device_name(i)
                candidates.append((i, vram, name))

        if not candidates:
            return None

        # Pick the GPU with the most VRAM
        candidates.sort(key=lambda x: x[1], reverse=True)
        best = candidates[0]
        logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)")
        return best[0]

    @staticmethod
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        Declared as a staticmethod on the client but bound onto the pipeline
        with ``types.MethodType``, so ``self`` here is the *pipeline*, not the
        client.

        The original _get_qwen_prompt_embeds sends model_inputs to execution_device
        (main GPU), then calls text_encoder on a different GPU, causing a device
        mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer

        Returns:
            Tuple ``(prompt_embeds, encoder_attention_mask)``, both on the
            execution device.
        """
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]

        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        # Drop the leading template tokens from every sequence
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        # Right-pad every sequence (and its mask) with zeros to the longest one
        max_seq_len = max(e.size(0) for e in split_hidden_states)
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])

        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)

        return prompt_embeds, encoder_attention_mask

    def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool:
        """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary.

        Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline.

        Args:
            model_id: Hugging Face model id to load each component from.
            main_idx: CUDA index hosting the transformer and VAE.
            encoder_idx: CUDA index hosting the text encoder.

        Returns:
            True once ``self.pipe`` is assembled. Loading errors propagate as
            exceptions to the caller.
        """
        from diffusers import QwenImageEditPipeline
        from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
        from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
        from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
        from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

        main_dev = f"cuda:{main_idx}"
        enc_dev = f"cuda:{encoder_idx}"

        logger.info(f"Loading pinned 2-GPU: transformer+VAE → {main_dev}, text_encoder → {enc_dev}")

        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")
        tokenizer = Qwen2Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer")
        processor = Qwen2VLProcessor.from_pretrained(
            model_id, subfolder="processor")

        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=self.dtype,
        ).to(enc_dev)
        logger.info(f"  text_encoder loaded on {enc_dev}")

        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=self.dtype,
        ).to(main_dev)
        logger.info(f"  transformer loaded on {main_dev}")

        vae = AutoencoderKLQwenImage.from_pretrained(
            model_id, subfolder="vae", torch_dtype=self.dtype,
        ).to(main_dev)
        vae.enable_tiling()
        logger.info(f"  VAE loaded on {main_dev}")

        # Fix 1: Override _execution_device to force main GPU
        # Without this, pipeline returns text_encoder's device, causing VAE
        # to receive tensors on the wrong GPU.
        # The override lives on a private subclass rather than on
        # QwenImageEditPipeline itself, so other pipeline instances in the
        # process (and later loads with a different strategy) are unaffected.
        main_device = torch.device(main_dev)
        pinned_pipeline_cls = type(
            "PinnedQwenImageEditPipeline",
            (QwenImageEditPipeline,),
            {"_execution_device": property(lambda _pipe: main_device)},
        )

        self.pipe = pinned_pipeline_cls(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            processor=processor,
            transformer=transformer,
        )

        # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device
        self.pipe._get_qwen_prompt_embeds = types.MethodType(
            self._patched_get_qwen_prompt_embeds, self.pipe)

        self._loading_strategy = "pinned_multi_gpu"
        logger.info("Pinned 2-GPU pipeline ready")
        return True

    def _release_pipeline(self) -> None:
        """Drop the pipeline reference, reset load state and free cached VRAM."""
        self.pipe = None
        self._loaded = False
        self._loading_strategy = None
        # Collect first so that reference cycles holding tensors are broken
        # before asking the CUDA allocator to return its cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)

        The first eligible strategy is used; if it raises, the error is
        logged, any partially built pipeline is released, and no further
        strategy is tried.

        Returns:
            True if the model is (already) loaded, False if loading failed.
        """
        if self._loaded:
            return True

        try:
            from diffusers import QwenImageEditPipeline

            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)

            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")
            start_time = time.time()

            loaded = False

            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                if self.encoder_device:
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None
                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)

                if encoder_idx is not None:
                    loaded = self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)

            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                # NOTE(review): main_vram is measured in 1e9-byte GB while the
                # limit below is labelled GiB, so the real headroom is smaller
                # than 4 - confirm this is intended before changing it.
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id,
                    torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True

            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True

            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")

            self.pipe.set_progress_bar_config(disable=None)

            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)

        # Only reached on failure. Cleanup runs after the except block so the
        # exception's traceback no longer keeps half-loaded weights alive.
        self._release_pipeline()
        return False

    def unload_model(self):
        """Unload model from memory.

        Does nothing when no pipeline is held. Otherwise drops the pipeline,
        resets the load state/strategy and frees cached CUDA memory.
        """
        if self.pipe is None:
            return
        self._release_pipeline()
        logger.info("Qwen-Image-Edit-2511 unloaded")

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0,
        seed: Optional[int] = 42,
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        The model is loaded on first use. The image is always generated at
        ``NATIVE_RESOLUTION`` and then center-cropped and resized to the size
        mapped from ``request.aspect_ratio``. Only the first non-None input
        image of the request is passed to the pipeline.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control
            seed: Seed for the sampling noise. The default (42) reproduces
                the previous fixed-seed behavior; None leaves sampling
                unseeded.

        Returns:
            GenerationResult object (an error result on any failure; this
            method does not raise).
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")

        try:
            start_time = time.time()

            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)

            # Build input images list
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]

            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION

            # Use a dedicated CPU generator instead of torch.manual_seed():
            # it yields the same noise for a given seed without reseeding the
            # process-wide RNG that other code may depend on.
            generator = None
            if seed is not None:
                generator = torch.Generator(device="cpu").manual_seed(seed)

            gen_kwargs = {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                "generator": generator,
            }

            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]

            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")

            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")

    @staticmethod
    def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
        """Crop image to target aspect ratio, then resize to target dimensions.

        Centers the crop on the image so equal amounts are trimmed from each side.
        Uses LANCZOS resampling for the final resize.

        Args:
            image: Source image.
            target_w: Output width in pixels.
            target_h: Output height in pixels.

        Returns:
            A new image of size ``(target_w, target_h)``.
        """
        src_w, src_h = image.size
        target_ratio = target_w / target_h
        src_ratio = src_w / src_h

        if abs(target_ratio - src_ratio) < 0.01:
            # Already the right aspect ratio, just resize
            return image.resize((target_w, target_h), Image.LANCZOS)

        if target_ratio < src_ratio:
            # Target is taller/narrower than source → crop sides
            crop_w = int(src_h * target_ratio)
            offset = (src_w - crop_w) // 2
            image = image.crop((offset, 0, offset + crop_w, src_h))
        else:
            # Target is wider than source → crop top/bottom
            crop_h = int(src_w / target_ratio)
            offset = (src_h - crop_h) // 2
            image = image.crop((0, offset, src_w, offset + crop_h))

        return image.resize((target_w, target_h), Image.LANCZOS)

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (instance-level alias of ``get_dimensions``)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio.

        Only the first whitespace-separated token is used, so labels such as
        "16:9 (Landscape)" resolve to "16:9".

        Args:
            aspect_ratio: Ratio label, e.g. "16:9".

        Returns:
            ``(width, height)`` from ``ASPECT_RATIOS``, or
            ``DEFAULT_DIMENSIONS`` when the ratio is unknown, empty,
            whitespace-only or not a string.
        """
        tokens = aspect_ratio.split() if isinstance(aspect_ratio, str) else []
        ratio = tokens[0] if tokens else ""
        return cls.ASPECT_RATIOS.get(ratio, cls.DEFAULT_DIMENSIONS)