# NOTE: removed a non-Python page-scrape artifact ("Spaces: Running on Zero")
# that preceded the module docstring and would be a syntax error.
"""
Qwen-Image-Edit Client
======================

Client for Qwen-Image-Edit-2511 local image editing.
Supports multi-image editing with improved consistency.

GPU loading strategies (benchmarked on A6000 + A5000):
    Pinned 2-GPU:        169.9s (4.25s/step) - 1.36x vs baseline
    Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
    CPU offload:         231.5s (5.79s/step) - baseline
"""
import logging
import time
import types
from typing import Optional, List

import torch
from PIL import Image

from .models import GenerationRequest, GenerationResult

logger = logging.getLogger(__name__)
class QwenImageEditClient:
    """
    Client for Qwen-Image-Edit-2511 model.

    Supports:
    - Multi-image editing (up to multiple reference images)
    - Precise text editing
    - Improved character consistency
    - LoRA integration
    """

    # Model variants (variant key -> HuggingFace model id)
    MODELS = {
        "full": "Qwen/Qwen-Image-Edit",  # Official Qwen model
    }

    # Legacy compatibility
    MODEL_ID = MODELS["full"]

    # Aspect ratio to dimensions mapping (target output sizes, (width, height))
    ASPECT_RATIOS = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "21:9": (1680, 720),  # Cinematic ultra-wide
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
        "3:4": (1104, 1472),
        "4:3": (1472, 1104),
        "4:5": (1056, 1320),
        "5:4": (1320, 1056),
    }

    # Proven native generation resolution. Tested resolutions:
    #   1104x1472 (3:4)  -> CLEAN output (face views in v1 test)
    #   928x1664  (9:16) -> VAE tiling noise / garbage
    #   1328x1328 (1:1)  -> VAE tiling noise / garbage
    #   896x1184  (auto) -> garbage
    # Always generate at 1104x1472, then crop+resize to target.
    NATIVE_RESOLUTION = (1104, 1472)

    # VRAM thresholds (GB) used by load_model() to pick a loading strategy.
    # Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
    BALANCED_VRAM_THRESHOLD_GB = 45  # Single GPU balanced (needs ~42GB + headroom)
    MAIN_GPU_MIN_VRAM_GB = 42        # Transformer + VAE minimum
    ENCODER_GPU_MIN_VRAM_GB = 17     # Text encoder minimum
| def __init__( | |
| self, | |
| model_variant: str = "full", # Use full model (~50GB) | |
| device: str = "cuda", | |
| dtype: torch.dtype = torch.bfloat16, | |
| enable_cpu_offload: bool = True, | |
| encoder_device: Optional[str] = None, | |
| ): | |
| """ | |
| Initialize Qwen-Image-Edit client. | |
| Args: | |
| model_variant: Model variant ("full" for ~50GB) | |
| device: Device to use for transformer+VAE (cuda or cuda:N) | |
| dtype: Data type for model weights | |
| enable_cpu_offload: Enable CPU offload to save VRAM | |
| encoder_device: Explicit device for text_encoder (e.g. "cuda:3"). | |
| If None, auto-detected from available GPUs. | |
| """ | |
| self.model_variant = model_variant | |
| self.device = device | |
| self.dtype = dtype | |
| self.enable_cpu_offload = enable_cpu_offload | |
| self.encoder_device = encoder_device | |
| self.pipe = None | |
| self._loaded = False | |
| self._loading_strategy = None | |
| logger.info(f"QwenImageEditClient initialized (variant: {model_variant})") | |
| def _get_gpu_vram_gb(device_idx: int) -> float: | |
| """Get total VRAM in GB for a specific GPU.""" | |
| if not torch.cuda.is_available(): | |
| return 0.0 | |
| if device_idx >= torch.cuda.device_count(): | |
| return 0.0 | |
| return torch.cuda.get_device_properties(device_idx).total_memory / 1e9 | |
| def _get_vram_gb(self) -> float: | |
| """Get available VRAM in GB for the main target device.""" | |
| device_idx = self._parse_device_idx(self.device) | |
| return self._get_gpu_vram_gb(device_idx) | |
| def _parse_device_idx(device: str) -> int: | |
| """Parse CUDA device index from device string.""" | |
| if device.startswith("cuda:"): | |
| try: | |
| return int(device.split(":")[1]) | |
| except (ValueError, IndexError): | |
| pass | |
| return 0 | |
| def _find_encoder_gpu(self, main_idx: int) -> Optional[int]: | |
| """Find a secondary GPU suitable for text_encoder (>= 17GB VRAM). | |
| Prefers GPUs with more VRAM. Skips the main GPU. | |
| """ | |
| if not torch.cuda.is_available(): | |
| return None | |
| candidates = [] | |
| for i in range(torch.cuda.device_count()): | |
| if i == main_idx: | |
| continue | |
| vram = self._get_gpu_vram_gb(i) | |
| if vram >= self.ENCODER_GPU_MIN_VRAM_GB: | |
| name = torch.cuda.get_device_name(i) | |
| candidates.append((i, vram, name)) | |
| if not candidates: | |
| return None | |
| # Pick the GPU with the most VRAM | |
| candidates.sort(key=lambda x: x[1], reverse=True) | |
| best = candidates[0] | |
| logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)") | |
| return best[0] | |
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        NOTE: this function is meant to be installed on the *pipeline* via
        ``types.MethodType`` in ``_load_pinned_multi_gpu``, so at call time
        ``self`` is expected to be the QwenImageEditPipeline instance —
        ``self.text_encoder``, ``self.processor``, ``self.prompt_template_encode``
        and ``self._extract_masked_hidden`` are all pipeline attributes, not
        attributes of this client.

        The original _get_qwen_prompt_embeds sends model_inputs to
        execution_device (main GPU), then calls text_encoder on a different
        GPU, causing a device mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer
        """
        # Device that actually holds the text_encoder weights.
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        # Normalize to a batch of prompts and apply the chat template.
        prompt = [prompt] if isinstance(prompt, str) else prompt
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]

        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        # Last hidden layer, split per prompt by attention mask; then drop the
        # template prefix tokens (drop_idx) from each sequence.
        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        # Right-pad embeddings and masks to the longest sequence in the batch.
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])

        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)
        return prompt_embeds, encoder_attention_mask
| def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool: | |
| """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary. | |
| Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline. | |
| """ | |
| from diffusers import QwenImageEditPipeline | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel | |
| from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage | |
| from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor | |
| main_dev = f"cuda:{main_idx}" | |
| enc_dev = f"cuda:{encoder_idx}" | |
| logger.info(f"Loading pinned 2-GPU: transformer+VAE → {main_dev}, text_encoder → {enc_dev}") | |
| scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( | |
| model_id, subfolder="scheduler") | |
| tokenizer = Qwen2Tokenizer.from_pretrained( | |
| model_id, subfolder="tokenizer") | |
| processor = Qwen2VLProcessor.from_pretrained( | |
| model_id, subfolder="processor") | |
| text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( | |
| model_id, subfolder="text_encoder", torch_dtype=self.dtype, | |
| ).to(enc_dev) | |
| logger.info(f" text_encoder loaded on {enc_dev}") | |
| transformer = QwenImageTransformer2DModel.from_pretrained( | |
| model_id, subfolder="transformer", torch_dtype=self.dtype, | |
| ).to(main_dev) | |
| logger.info(f" transformer loaded on {main_dev}") | |
| vae = AutoencoderKLQwenImage.from_pretrained( | |
| model_id, subfolder="vae", torch_dtype=self.dtype, | |
| ).to(main_dev) | |
| vae.enable_tiling() | |
| logger.info(f" VAE loaded on {main_dev}") | |
| self.pipe = QwenImageEditPipeline( | |
| scheduler=scheduler, vae=vae, text_encoder=text_encoder, | |
| tokenizer=tokenizer, processor=processor, transformer=transformer, | |
| ) | |
| # Fix 1: Override _execution_device to force main GPU | |
| # Without this, pipeline returns text_encoder's device, causing VAE | |
| # to receive tensors on the wrong GPU | |
| main_device = torch.device(main_dev) | |
| QwenImageEditPipeline._execution_device = property(lambda self: main_device) | |
| # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device | |
| self.pipe._get_qwen_prompt_embeds = types.MethodType( | |
| self._patched_get_qwen_prompt_embeds, self.pipe) | |
| self._loading_strategy = "pinned_multi_gpu" | |
| logger.info(f"Pinned 2-GPU pipeline ready") | |
| return True | |
    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)

        Returns:
            True on success, False on any failure (error is logged).
        """
        # Idempotent: repeat calls are no-ops once loaded.
        if self._loaded:
            return True
        try:
            from diffusers import QwenImageEditPipeline

            # Unknown variants silently fall back to the "full" model.
            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)

            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")
            start_time = time.time()
            loaded = False

            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                if self.encoder_device:
                    # An explicitly requested encoder device must still meet
                    # the VRAM minimum; otherwise auto-detect below.
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None
                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)
                if encoder_idx is not None:
                    self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)
                    loaded = True

            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                # Leave ~4 GiB headroom for activations / VAE work.
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True

            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True

            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")

            self.pipe.set_progress_bar_config(disable=None)

            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")
            self._loaded = True
            return True
        except Exception as e:
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)
            return False
| def unload_model(self): | |
| """Unload model from memory.""" | |
| if self.pipe is not None: | |
| del self.pipe | |
| self.pipe = None | |
| self._loaded = False | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| logger.info("Qwen-Image-Edit-2511 unloaded") | |
    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        Loads the model on first use, renders at NATIVE_RESOLUTION
        (1104x1472), then center-crops and resizes to the requested
        aspect ratio.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control

        Returns:
            GenerationResult object (an error result on load or
            generation failure — this method does not raise).
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")
        try:
            start_time = time.time()

            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)

            # Build input images list, dropping any None entries.
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]

            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION

            gen_kwargs = {
                "prompt": request.prompt,
                # Single space keeps the negative prompt non-empty.
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                # NOTE(review): hard-coded seed — every call uses identical
                # noise, so repeat requests with the same inputs produce the
                # same image. Confirm deterministic output is intended.
                "generator": torch.manual_seed(42),
            }

            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]

            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")

            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
            image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )
        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")
| def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image: | |
| """Crop image to target aspect ratio, then resize to target dimensions. | |
| Centers the crop on the image so equal amounts are trimmed from | |
| each side. Uses LANCZOS for high-quality downscaling. | |
| """ | |
| src_w, src_h = image.size | |
| target_ratio = target_w / target_h | |
| src_ratio = src_w / src_h | |
| if abs(target_ratio - src_ratio) < 0.01: | |
| # Already the right aspect ratio, just resize | |
| return image.resize((target_w, target_h), Image.LANCZOS) | |
| if target_ratio < src_ratio: | |
| # Target is taller/narrower than source → crop sides | |
| crop_w = int(src_h * target_ratio) | |
| offset = (src_w - crop_w) // 2 | |
| image = image.crop((offset, 0, offset + crop_w, src_h)) | |
| else: | |
| # Target is wider than source → crop top/bottom | |
| crop_h = int(src_w / target_ratio) | |
| offset = (src_h - crop_h) // 2 | |
| image = image.crop((0, offset, src_w, offset + crop_h)) | |
| return image.resize((target_w, target_h), Image.LANCZOS) | |
| def _get_dimensions(self, aspect_ratio: str) -> tuple: | |
| """Get pixel dimensions for aspect ratio.""" | |
| ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio | |
| return self.ASPECT_RATIOS.get(ratio, (1024, 1024)) | |
| def is_healthy(self) -> bool: | |
| """Check if model is loaded and ready.""" | |
| return self._loaded and self.pipe is not None | |
| def get_dimensions(cls, aspect_ratio: str) -> tuple: | |
| """Get pixel dimensions for aspect ratio.""" | |
| ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio | |
| return cls.ASPECT_RATIOS.get(ratio, (1024, 1024)) | |